mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-21 02:28:12 -05:00
Compare commits
87 Commits
feat/batch
...
fix/sde-ra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
096c17465b | ||
|
|
d09dfc3e9b | ||
|
|
66f524cae7 | ||
|
|
9ba50130a1 | ||
|
|
d4cf2d2666 | ||
|
|
b8b589c150 | ||
|
|
d93900a8de | ||
|
|
dc96a3e79d | ||
|
|
c076f1397e | ||
|
|
2568aafc0b | ||
|
|
b6e369c745 | ||
|
|
ecabfc252b | ||
|
|
d162b78767 | ||
|
|
eb6c317f04 | ||
|
|
6d7223238f | ||
|
|
8607d124c5 | ||
|
|
23497bf759 | ||
|
|
3d93851dba | ||
|
|
9bacd77a79 | ||
|
|
1b158f62c4 | ||
|
|
6ad565d84c | ||
|
|
04229082d6 | ||
|
|
03c27412f7 | ||
|
|
f0613bb0ef | ||
|
|
0e9f92b868 | ||
|
|
7d0cc6ec3f | ||
|
|
2f8b928486 | ||
|
|
0d3c27f46c | ||
|
|
cff91f06d3 | ||
|
|
1d5d187ba1 | ||
|
|
1ac14a1e43 | ||
|
|
cfc3a20565 | ||
|
|
05ae4e283c | ||
|
|
f06fee4581 | ||
|
|
9091e19de8 | ||
|
|
0a0b7141af | ||
|
|
446fb4a438 | ||
|
|
ab5d938a1d | ||
|
|
9942af756a | ||
|
|
06742faca7 | ||
|
|
d2bddf7f91 | ||
|
|
bf94412d14 | ||
|
|
e080fd1e08 | ||
|
|
eeef1e08f8 | ||
|
|
b3b94b5a8d | ||
|
|
5c9787c145 | ||
|
|
cf72eba15c | ||
|
|
a6f9396a30 | ||
|
|
118d5b387b | ||
|
|
db545f8801 | ||
|
|
b0d72b15b3 | ||
|
|
4e0949fa55 | ||
|
|
f028342f5b | ||
|
|
7021467048 | ||
|
|
26ef5249b1 | ||
|
|
87424be95d | ||
|
|
366952f810 | ||
|
|
450e95de59 | ||
|
|
0ba8a0ea6c | ||
|
|
f4981f26d5 | ||
|
|
43d6312587 | ||
|
|
0d125bf3e4 | ||
|
|
921ccad04d | ||
|
|
05c9207e7b | ||
|
|
3fc789a7ee | ||
|
|
008362918e | ||
|
|
8fc75a71ee | ||
|
|
82d259f43b | ||
|
|
818c55cd53 | ||
|
|
0db1e97119 | ||
|
|
ed76250dba | ||
|
|
4d22cafdad | ||
|
|
8a4e5f73aa | ||
|
|
4599575e65 | ||
|
|
242d860a47 | ||
|
|
0c1a7e72d4 | ||
|
|
11a44b944d | ||
|
|
fd7b842419 | ||
|
|
7292d89108 | ||
|
|
437f45a97f | ||
|
|
13ef33ed64 | ||
|
|
86d8b46fca | ||
|
|
df53b62048 | ||
|
|
55d3f04476 | ||
|
|
72ebe2ce68 | ||
|
|
7cd8b2f207 | ||
|
|
35dd58e273 |
14
.github/workflows/style-checks.yml
vendored
14
.github/workflows/style-checks.yml
vendored
@@ -1,13 +1,14 @@
|
|||||||
name: Black # TODO: add isort and flake8 later
|
name: style checks
|
||||||
|
# just formatting for now
|
||||||
|
# TODO: add isort and flake8 later
|
||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request: {}
|
pull_request:
|
||||||
push:
|
push:
|
||||||
branches: master
|
branches: main
|
||||||
tags: "*"
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
black:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
@@ -19,8 +20,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies with pip
|
- name: Install dependencies with pip
|
||||||
run: |
|
run: |
|
||||||
pip install --upgrade pip wheel
|
pip install black
|
||||||
pip install .[test]
|
|
||||||
|
|
||||||
# - run: isort --check-only .
|
# - run: isort --check-only .
|
||||||
- run: black --check .
|
- run: black --check .
|
||||||
|
|||||||
50
.github/workflows/test-invoke-pip-skip.yml
vendored
50
.github/workflows/test-invoke-pip-skip.yml
vendored
@@ -1,50 +0,0 @@
|
|||||||
name: Test invoke.py pip
|
|
||||||
|
|
||||||
# This is a dummy stand-in for the actual tests
|
|
||||||
# we don't need to run python tests on non-Python changes
|
|
||||||
# But PRs require passing tests to be mergeable
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
paths:
|
|
||||||
- '**'
|
|
||||||
- '!pyproject.toml'
|
|
||||||
- '!invokeai/**'
|
|
||||||
- '!tests/**'
|
|
||||||
- 'invokeai/frontend/web/**'
|
|
||||||
merge_group:
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
matrix:
|
|
||||||
if: github.event.pull_request.draft == false
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
python-version:
|
|
||||||
- '3.10'
|
|
||||||
pytorch:
|
|
||||||
- linux-cuda-11_7
|
|
||||||
- linux-rocm-5_2
|
|
||||||
- linux-cpu
|
|
||||||
- macos-default
|
|
||||||
- windows-cpu
|
|
||||||
include:
|
|
||||||
- pytorch: linux-cuda-11_7
|
|
||||||
os: ubuntu-22.04
|
|
||||||
- pytorch: linux-rocm-5_2
|
|
||||||
os: ubuntu-22.04
|
|
||||||
- pytorch: linux-cpu
|
|
||||||
os: ubuntu-22.04
|
|
||||||
- pytorch: macos-default
|
|
||||||
os: macOS-12
|
|
||||||
- pytorch: windows-cpu
|
|
||||||
os: windows-2022
|
|
||||||
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
steps:
|
|
||||||
- name: skip
|
|
||||||
run: echo "no build required"
|
|
||||||
24
.github/workflows/test-invoke-pip.yml
vendored
24
.github/workflows/test-invoke-pip.yml
vendored
@@ -3,16 +3,7 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- 'main'
|
- 'main'
|
||||||
paths:
|
|
||||||
- 'pyproject.toml'
|
|
||||||
- 'invokeai/**'
|
|
||||||
- '!invokeai/frontend/web/**'
|
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
|
||||||
- 'pyproject.toml'
|
|
||||||
- 'invokeai/**'
|
|
||||||
- 'tests/**'
|
|
||||||
- '!invokeai/frontend/web/**'
|
|
||||||
types:
|
types:
|
||||||
- 'ready_for_review'
|
- 'ready_for_review'
|
||||||
- 'opened'
|
- 'opened'
|
||||||
@@ -65,10 +56,23 @@ jobs:
|
|||||||
id: checkout-sources
|
id: checkout-sources
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Check for changed python files
|
||||||
|
id: changed-files
|
||||||
|
uses: tj-actions/changed-files@v37
|
||||||
|
with:
|
||||||
|
files_yaml: |
|
||||||
|
python:
|
||||||
|
- 'pyproject.toml'
|
||||||
|
- 'invokeai/**'
|
||||||
|
- '!invokeai/frontend/web/**'
|
||||||
|
- 'tests/**'
|
||||||
|
|
||||||
- name: set test prompt to main branch validation
|
- name: set test prompt to main branch validation
|
||||||
|
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||||
|
|
||||||
- name: setup python
|
- name: setup python
|
||||||
|
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
@@ -76,6 +80,7 @@ jobs:
|
|||||||
cache-dependency-path: pyproject.toml
|
cache-dependency-path: pyproject.toml
|
||||||
|
|
||||||
- name: install invokeai
|
- name: install invokeai
|
||||||
|
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||||
env:
|
env:
|
||||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||||
run: >
|
run: >
|
||||||
@@ -83,6 +88,7 @@ jobs:
|
|||||||
--editable=".[test]"
|
--editable=".[test]"
|
||||||
|
|
||||||
- name: run pytest
|
- name: run pytest
|
||||||
|
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||||
id: run-pytest
|
id: run-pytest
|
||||||
run: pytest
|
run: pytest
|
||||||
|
|
||||||
|
|||||||
13
README.md
13
README.md
@@ -184,8 +184,9 @@ the command `npm install -g yarn` if needed)
|
|||||||
6. Configure InvokeAI and install a starting set of image generation models (you only need to do this once):
|
6. Configure InvokeAI and install a starting set of image generation models (you only need to do this once):
|
||||||
|
|
||||||
```terminal
|
```terminal
|
||||||
invokeai-configure
|
invokeai-configure --root .
|
||||||
```
|
```
|
||||||
|
Don't miss the dot at the end!
|
||||||
|
|
||||||
7. Launch the web server (do it every time you run InvokeAI):
|
7. Launch the web server (do it every time you run InvokeAI):
|
||||||
|
|
||||||
@@ -193,15 +194,9 @@ the command `npm install -g yarn` if needed)
|
|||||||
invokeai-web
|
invokeai-web
|
||||||
```
|
```
|
||||||
|
|
||||||
8. Build Node.js assets
|
8. Point your browser to http://localhost:9090 to bring up the web interface.
|
||||||
|
|
||||||
```terminal
|
9. Type `banana sushi` in the box on the top left and click `Invoke`.
|
||||||
cd invokeai/frontend/web/
|
|
||||||
yarn vite build
|
|
||||||
```
|
|
||||||
|
|
||||||
9. Point your browser to http://localhost:9090 to bring up the web interface.
|
|
||||||
10. Type `banana sushi` in the box on the top left and click `Invoke`.
|
|
||||||
|
|
||||||
Be sure to activate the virtual environment each time before re-launching InvokeAI,
|
Be sure to activate the virtual environment each time before re-launching InvokeAI,
|
||||||
using `source .venv/bin/activate` or `.venv\Scripts\activate`.
|
using `source .venv/bin/activate` or `.venv\Scripts\activate`.
|
||||||
|
|||||||
@@ -192,8 +192,10 @@ manager, please follow these steps:
|
|||||||
your outputs.
|
your outputs.
|
||||||
|
|
||||||
```terminal
|
```terminal
|
||||||
invokeai-configure
|
invokeai-configure --root .
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Don't miss the dot at the end of the command!
|
||||||
|
|
||||||
The script `invokeai-configure` will interactively guide you through the
|
The script `invokeai-configure` will interactively guide you through the
|
||||||
process of downloading and installing the weights files needed for InvokeAI.
|
process of downloading and installing the weights files needed for InvokeAI.
|
||||||
@@ -225,12 +227,6 @@ manager, please follow these steps:
|
|||||||
|
|
||||||
!!! warning "Make sure that the virtual environment is activated, which should create `(.venv)` in front of your prompt!"
|
!!! warning "Make sure that the virtual environment is activated, which should create `(.venv)` in front of your prompt!"
|
||||||
|
|
||||||
=== "CLI"
|
|
||||||
|
|
||||||
```bash
|
|
||||||
invokeai
|
|
||||||
```
|
|
||||||
|
|
||||||
=== "local Webserver"
|
=== "local Webserver"
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -243,6 +239,12 @@ manager, please follow these steps:
|
|||||||
invokeai --web --host 0.0.0.0
|
invokeai --web --host 0.0.0.0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
=== "CLI"
|
||||||
|
|
||||||
|
```bash
|
||||||
|
invokeai
|
||||||
|
```
|
||||||
|
|
||||||
If you choose the run the web interface, point your browser at
|
If you choose the run the web interface, point your browser at
|
||||||
http://localhost:9090 in order to load the GUI.
|
http://localhost:9090 in order to load the GUI.
|
||||||
|
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ installation. Examples:
|
|||||||
invokeai-model-install --list controlnet
|
invokeai-model-install --list controlnet
|
||||||
|
|
||||||
# (install the model at the indicated URL)
|
# (install the model at the indicated URL)
|
||||||
invokeai-model-install --add http://civitai.com/2860
|
invokeai-model-install --add https://civitai.com/api/download/models/128713
|
||||||
|
|
||||||
# (delete the named model)
|
# (delete the named model)
|
||||||
invokeai-model-install --delete sd-1/main/analog-diffusion
|
invokeai-model-install --delete sd-1/main/analog-diffusion
|
||||||
@@ -170,4 +170,4 @@ elsewhere on disk and they will be autoimported. You can also create
|
|||||||
subfolders and organize them as you wish.
|
subfolders and organize them as you wish.
|
||||||
|
|
||||||
The location of the autoimport directories are controlled by settings
|
The location of the autoimport directories are controlled by settings
|
||||||
in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
|
in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
import os
|
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
from invokeai.app.services.board_image_record_storage import (
|
||||||
SqliteBoardImageRecordStorage,
|
SqliteBoardImageRecordStorage,
|
||||||
)
|
)
|
||||||
@@ -30,6 +29,7 @@ from ..services.invoker import Invoker
|
|||||||
from ..services.processor import DefaultInvocationProcessor
|
from ..services.processor import DefaultInvocationProcessor
|
||||||
from ..services.sqlite import SqliteItemStorage
|
from ..services.sqlite import SqliteItemStorage
|
||||||
from ..services.model_manager_service import ModelManagerService
|
from ..services.model_manager_service import ModelManagerService
|
||||||
|
from ..services.invocation_stats import InvocationStatsService
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ logger = InvokeAILogger.getLogger()
|
|||||||
class ApiDependencies:
|
class ApiDependencies:
|
||||||
"""Contains and initializes all dependencies for the API"""
|
"""Contains and initializes all dependencies for the API"""
|
||||||
|
|
||||||
invoker: Optional[Invoker] = None
|
invoker: Invoker
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
|
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
|
||||||
@@ -68,8 +68,9 @@ class ApiDependencies:
|
|||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
|
|
||||||
# TODO: build a file/path manager?
|
# TODO: build a file/path manager?
|
||||||
db_location = config.db_path
|
db_path = config.db_path
|
||||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
db_location = str(db_path)
|
||||||
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions"
|
||||||
@@ -128,6 +129,7 @@ class ApiDependencies:
|
|||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
configuration=config,
|
configuration=config,
|
||||||
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +1,30 @@
|
|||||||
from fastapi import Body, HTTPException, Path, Query
|
from fastapi import Body, HTTPException
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
|
from pydantic import BaseModel, Field
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
|
||||||
from invokeai.app.services.models.board_record import BoardDTO
|
|
||||||
from invokeai.app.services.models.image_record import ImageDTO
|
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||||
|
|
||||||
|
|
||||||
|
class AddImagesToBoardResult(BaseModel):
|
||||||
|
board_id: str = Field(description="The id of the board the images were added to")
|
||||||
|
added_image_names: list[str] = Field(description="The image names that were added to the board")
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveImagesFromBoardResult(BaseModel):
|
||||||
|
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
|
||||||
|
|
||||||
|
|
||||||
@board_images_router.post(
|
@board_images_router.post(
|
||||||
"/",
|
"/",
|
||||||
operation_id="create_board_image",
|
operation_id="add_image_to_board",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The image was added to a board successfully"},
|
201: {"description": "The image was added to a board successfully"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def create_board_image(
|
async def add_image_to_board(
|
||||||
board_id: str = Body(description="The id of the board to add to"),
|
board_id: str = Body(description="The id of the board to add to"),
|
||||||
image_name: str = Body(description="The name of the image to add"),
|
image_name: str = Body(description="The name of the image to add"),
|
||||||
):
|
):
|
||||||
@@ -29,26 +35,78 @@ async def create_board_image(
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Failed to add to board")
|
raise HTTPException(status_code=500, detail="Failed to add image to board")
|
||||||
|
|
||||||
|
|
||||||
@board_images_router.delete(
|
@board_images_router.delete(
|
||||||
"/",
|
"/",
|
||||||
operation_id="remove_board_image",
|
operation_id="remove_image_from_board",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The image was removed from the board successfully"},
|
201: {"description": "The image was removed from the board successfully"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def remove_board_image(
|
async def remove_image_from_board(
|
||||||
board_id: str = Body(description="The id of the board"),
|
image_name: str = Body(description="The name of the image to remove", embed=True),
|
||||||
image_name: str = Body(description="The name of the image to remove"),
|
|
||||||
):
|
):
|
||||||
"""Deletes a board_image"""
|
"""Removes an image from its board, if it had one"""
|
||||||
try:
|
try:
|
||||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(
|
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||||
board_id=board_id, image_name=image_name
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
raise HTTPException(status_code=500, detail="Failed to remove image from board")
|
||||||
|
|
||||||
|
|
||||||
|
@board_images_router.post(
|
||||||
|
"/batch",
|
||||||
|
operation_id="add_images_to_board",
|
||||||
|
responses={
|
||||||
|
201: {"description": "Images were added to board successfully"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=AddImagesToBoardResult,
|
||||||
|
)
|
||||||
|
async def add_images_to_board(
|
||||||
|
board_id: str = Body(description="The id of the board to add to"),
|
||||||
|
image_names: list[str] = Body(description="The names of the images to add", embed=True),
|
||||||
|
) -> AddImagesToBoardResult:
|
||||||
|
"""Adds a list of images to a board"""
|
||||||
|
try:
|
||||||
|
added_image_names: list[str] = []
|
||||||
|
for image_name in image_names:
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||||
|
board_id=board_id, image_name=image_name
|
||||||
|
)
|
||||||
|
added_image_names.append(image_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to add images to board")
|
||||||
|
|
||||||
|
|
||||||
|
@board_images_router.post(
|
||||||
|
"/batch/delete",
|
||||||
|
operation_id="remove_images_from_board",
|
||||||
|
responses={
|
||||||
|
201: {"description": "Images were removed from board successfully"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=RemoveImagesFromBoardResult,
|
||||||
|
)
|
||||||
|
async def remove_images_from_board(
|
||||||
|
image_names: list[str] = Body(description="The names of the images to remove", embed=True),
|
||||||
|
) -> RemoveImagesFromBoardResult:
|
||||||
|
"""Removes a list of images from their board, if they had one"""
|
||||||
|
try:
|
||||||
|
removed_image_names: list[str] = []
|
||||||
|
for image_name in image_names:
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||||
|
removed_image_names.append(image_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to remove images from board")
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadF
|
|||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import ImageMetadata
|
from invokeai.app.invocations.metadata import ImageMetadata
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
@@ -25,7 +26,7 @@ IMAGE_MAX_AGE = 31536000
|
|||||||
|
|
||||||
|
|
||||||
@images_router.post(
|
@images_router.post(
|
||||||
"/",
|
"/upload",
|
||||||
operation_id="upload_image",
|
operation_id="upload_image",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The image was uploaded successfully"},
|
201: {"description": "The image was uploaded successfully"},
|
||||||
@@ -77,7 +78,7 @@ async def upload_image(
|
|||||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||||
|
|
||||||
|
|
||||||
@images_router.delete("/{image_name}", operation_id="delete_image")
|
@images_router.delete("/i/{image_name}", operation_id="delete_image")
|
||||||
async def delete_image(
|
async def delete_image(
|
||||||
image_name: str = Path(description="The name of the image to delete"),
|
image_name: str = Path(description="The name of the image to delete"),
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -103,7 +104,7 @@ async def clear_intermediates() -> int:
|
|||||||
|
|
||||||
|
|
||||||
@images_router.patch(
|
@images_router.patch(
|
||||||
"/{image_name}",
|
"/i/{image_name}",
|
||||||
operation_id="update_image",
|
operation_id="update_image",
|
||||||
response_model=ImageDTO,
|
response_model=ImageDTO,
|
||||||
)
|
)
|
||||||
@@ -120,7 +121,7 @@ async def update_image(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}",
|
"/i/{image_name}",
|
||||||
operation_id="get_image_dto",
|
operation_id="get_image_dto",
|
||||||
response_model=ImageDTO,
|
response_model=ImageDTO,
|
||||||
)
|
)
|
||||||
@@ -136,7 +137,7 @@ async def get_image_dto(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}/metadata",
|
"/i/{image_name}/metadata",
|
||||||
operation_id="get_image_metadata",
|
operation_id="get_image_metadata",
|
||||||
response_model=ImageMetadata,
|
response_model=ImageMetadata,
|
||||||
)
|
)
|
||||||
@@ -152,7 +153,7 @@ async def get_image_metadata(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}/full",
|
"/i/{image_name}/full",
|
||||||
operation_id="get_image_full",
|
operation_id="get_image_full",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@@ -187,7 +188,7 @@ async def get_image_full(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}/thumbnail",
|
"/i/{image_name}/thumbnail",
|
||||||
operation_id="get_image_thumbnail",
|
operation_id="get_image_thumbnail",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@@ -216,7 +217,7 @@ async def get_image_thumbnail(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}/urls",
|
"/i/{image_name}/urls",
|
||||||
operation_id="get_image_urls",
|
operation_id="get_image_urls",
|
||||||
response_model=ImageUrlsDTO,
|
response_model=ImageUrlsDTO,
|
||||||
)
|
)
|
||||||
@@ -265,3 +266,24 @@ async def list_image_dtos(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return image_dtos
|
return image_dtos
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteImagesFromListResult(BaseModel):
|
||||||
|
deleted_images: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
|
||||||
|
async def delete_images_from_list(
|
||||||
|
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
|
||||||
|
) -> DeleteImagesFromListResult:
|
||||||
|
try:
|
||||||
|
deleted_images: list[str] = []
|
||||||
|
for image_name in image_names:
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.images.delete(image_name)
|
||||||
|
deleted_images.append(image_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
|||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
@@ -311,6 +312,7 @@ def invoke_cli():
|
|||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
configuration=config,
|
configuration=config,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -109,12 +109,15 @@ class CompelInvocation(BaseInvocation):
|
|||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
context.services.model_manager.get_model(
|
(
|
||||||
model_name=name,
|
name,
|
||||||
base_model=self.clip.text_encoder.base_model,
|
context.services.model_manager.get_model(
|
||||||
model_type=ModelType.TextualInversion,
|
model_name=name,
|
||||||
context=context,
|
base_model=self.clip.text_encoder.base_model,
|
||||||
).context.model
|
model_type=ModelType.TextualInversion,
|
||||||
|
context=context,
|
||||||
|
).context.model,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
@@ -173,7 +176,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
class SDXLPromptInvocationBase:
|
class SDXLPromptInvocationBase:
|
||||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(),
|
||||||
context=context,
|
context=context,
|
||||||
@@ -197,12 +200,15 @@ class SDXLPromptInvocationBase:
|
|||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
context.services.model_manager.get_model(
|
(
|
||||||
model_name=name,
|
name,
|
||||||
base_model=clip_field.text_encoder.base_model,
|
context.services.model_manager.get_model(
|
||||||
model_type=ModelType.TextualInversion,
|
model_name=name,
|
||||||
context=context,
|
base_model=clip_field.text_encoder.base_model,
|
||||||
).context.model
|
model_type=ModelType.TextualInversion,
|
||||||
|
context=context,
|
||||||
|
).context.model,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
@@ -210,8 +216,8 @@ class SDXLPromptInvocationBase:
|
|||||||
# print(traceback.format_exc())
|
# print(traceback.format_exc())
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
print(f'Warn: trigger: "{trigger}" not found')
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(
|
with ModelPatcher.apply_lora(
|
||||||
text_encoder_info.context.model, _lora_loader()
|
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
||||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
@@ -247,7 +253,7 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
return c, c_pooled, None
|
return c, c_pooled, None
|
||||||
|
|
||||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(),
|
||||||
context=context,
|
context=context,
|
||||||
@@ -271,12 +277,15 @@ class SDXLPromptInvocationBase:
|
|||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
context.services.model_manager.get_model(
|
(
|
||||||
model_name=name,
|
name,
|
||||||
base_model=clip_field.text_encoder.base_model,
|
context.services.model_manager.get_model(
|
||||||
model_type=ModelType.TextualInversion,
|
model_name=name,
|
||||||
context=context,
|
base_model=clip_field.text_encoder.base_model,
|
||||||
).context.model
|
model_type=ModelType.TextualInversion,
|
||||||
|
context=context,
|
||||||
|
).context.model,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
@@ -284,8 +293,8 @@ class SDXLPromptInvocationBase:
|
|||||||
# print(traceback.format_exc())
|
# print(traceback.format_exc())
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
print(f'Warn: trigger: "{trigger}" not found')
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(
|
with ModelPatcher.apply_lora(
|
||||||
text_encoder_info.context.model, _lora_loader()
|
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
||||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
@@ -357,11 +366,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
|
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
|
||||||
if self.style.strip() == "":
|
if self.style.strip() == "":
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
|
||||||
else:
|
else:
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
@@ -415,7 +424,8 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
# TODO: if there will appear lora for refiner - write proper prefix
|
||||||
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
@@ -467,11 +477,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
|
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
|
||||||
if self.style.strip() == "":
|
if self.style.strip() == "":
|
||||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
|
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
|
||||||
else:
|
else:
|
||||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
@@ -525,7 +535,8 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
# TODO: if there will appear lora for refiner - write proper prefix
|
||||||
|
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
import cv2
|
||||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -650,3 +651,147 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||||
|
"""Adjusts the Hue of an image."""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["img_hue_adjust"] = "img_hue_adjust"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: ImageField = Field(default=None, description="The image to adjust")
|
||||||
|
hue: int = Field(default=0, description="The degrees by which to rotate the hue")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
|
# Convert PIL image to OpenCV format (numpy array), note color channel
|
||||||
|
# ordering is changed from RGB to BGR
|
||||||
|
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
||||||
|
|
||||||
|
# Convert image to HSV color space
|
||||||
|
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||||
|
|
||||||
|
# Adjust the hue
|
||||||
|
hsv_image[:, :, 0] = (hsv_image[:, :, 0] + self.hue) % 180
|
||||||
|
|
||||||
|
# Convert image back to BGR color space
|
||||||
|
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
||||||
|
|
||||||
|
# Convert back to PIL format and to original color mode
|
||||||
|
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=pil_image,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||||
|
"""Adjusts the Luminosity (Value) of an image."""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: ImageField = Field(default=None, description="The image to adjust")
|
||||||
|
luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
|
# Convert PIL image to OpenCV format (numpy array), note color channel
|
||||||
|
# ordering is changed from RGB to BGR
|
||||||
|
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
||||||
|
|
||||||
|
# Convert image to HSV color space
|
||||||
|
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||||
|
|
||||||
|
# Adjust the luminosity (value)
|
||||||
|
hsv_image[:, :, 2] = numpy.clip(hsv_image[:, :, 2] * self.luminosity, 0, 255)
|
||||||
|
|
||||||
|
# Convert image back to BGR color space
|
||||||
|
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
||||||
|
|
||||||
|
# Convert back to PIL format and to original color mode
|
||||||
|
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=pil_image,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
||||||
|
"""Adjusts the Saturation of an image."""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: ImageField = Field(default=None, description="The image to adjust")
|
||||||
|
saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
|
# Convert PIL image to OpenCV format (numpy array), note color channel
|
||||||
|
# ordering is changed from RGB to BGR
|
||||||
|
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
||||||
|
|
||||||
|
# Convert image to HSV color space
|
||||||
|
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||||
|
|
||||||
|
# Adjust the saturation
|
||||||
|
hsv_image[:, :, 1] = numpy.clip(hsv_image[:, :, 1] * self.saturation, 0, 255)
|
||||||
|
|
||||||
|
# Convert image back to BGR color space
|
||||||
|
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
||||||
|
|
||||||
|
# Convert back to PIL format and to original color mode
|
||||||
|
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=pil_image,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from invokeai.app.invocations.metadata import CoreMetadata
|
|||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||||
|
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management import ModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
@@ -180,6 +180,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||||
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
|
# for ancestral and sde schedulers
|
||||||
|
generator = torch.Generator(device=unet.device)
|
||||||
|
generator.seed()
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
unconditioned_embeddings=uc,
|
unconditioned_embeddings=uc,
|
||||||
text_embeddings=c,
|
text_embeddings=c,
|
||||||
@@ -198,7 +202,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
# for ddim scheduler
|
# for ddim scheduler
|
||||||
eta=0.0, # ddim_eta
|
eta=0.0, # ddim_eta
|
||||||
# for ancestral and sde schedulers
|
# for ancestral and sde schedulers
|
||||||
generator=torch.Generator(device=unet.device).manual_seed(0),
|
generator=generator,
|
||||||
)
|
)
|
||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@@ -10,16 +10,17 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||||
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
|
||||||
class LoRAMetadataField(BaseModel):
|
class LoRAMetadataField(BaseModelExcludeNull):
|
||||||
"""LoRA metadata for an image generated in InvokeAI."""
|
"""LoRA metadata for an image generated in InvokeAI."""
|
||||||
|
|
||||||
lora: LoRAModelField = Field(description="The LoRA model")
|
lora: LoRAModelField = Field(description="The LoRA model")
|
||||||
weight: float = Field(description="The weight of the LoRA model")
|
weight: float = Field(description="The weight of the LoRA model")
|
||||||
|
|
||||||
|
|
||||||
class CoreMetadata(BaseModel):
|
class CoreMetadata(BaseModelExcludeNull):
|
||||||
"""Core generation metadata for an image generated in InvokeAI."""
|
"""Core generation metadata for an image generated in InvokeAI."""
|
||||||
|
|
||||||
generation_mode: str = Field(
|
generation_mode: str = Field(
|
||||||
@@ -70,7 +71,7 @@ class CoreMetadata(BaseModel):
|
|||||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
||||||
|
|
||||||
|
|
||||||
class ImageMetadata(BaseModel):
|
class ImageMetadata(BaseModelExcludeNull):
|
||||||
"""An image's generation metadata"""
|
"""An image's generation metadata"""
|
||||||
|
|
||||||
metadata: Optional[dict] = Field(
|
metadata: Optional[dict] = Field(
|
||||||
|
|||||||
@@ -262,6 +262,103 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""Model loader output"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
|
||||||
|
|
||||||
|
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
||||||
|
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
|
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
|
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
|
||||||
|
|
||||||
|
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
||||||
|
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||||
|
|
||||||
|
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||||
|
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||||
|
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
|
||||||
|
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "SDXL Lora Loader",
|
||||||
|
"tags": ["lora", "loader"],
|
||||||
|
"type_hints": {"lora": "lora_model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||||
|
if self.lora is None:
|
||||||
|
raise Exception("No LoRA provided")
|
||||||
|
|
||||||
|
base_model = self.lora.base_model
|
||||||
|
lora_name = self.lora.model_name
|
||||||
|
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
base_model=base_model,
|
||||||
|
model_name=lora_name,
|
||||||
|
model_type=ModelType.Lora,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unknown lora name: {lora_name}!")
|
||||||
|
|
||||||
|
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||||
|
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||||
|
|
||||||
|
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||||
|
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||||
|
|
||||||
|
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
|
||||||
|
raise Exception(f'Lora "{lora_name}" already applied to clip2')
|
||||||
|
|
||||||
|
output = SDXLLoraLoaderOutput()
|
||||||
|
|
||||||
|
if self.unet is not None:
|
||||||
|
output.unet = copy.deepcopy(self.unet)
|
||||||
|
output.unet.loras.append(
|
||||||
|
LoraInfo(
|
||||||
|
base_model=base_model,
|
||||||
|
model_name=lora_name,
|
||||||
|
model_type=ModelType.Lora,
|
||||||
|
submodel=None,
|
||||||
|
weight=self.weight,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.clip is not None:
|
||||||
|
output.clip = copy.deepcopy(self.clip)
|
||||||
|
output.clip.loras.append(
|
||||||
|
LoraInfo(
|
||||||
|
base_model=base_model,
|
||||||
|
model_name=lora_name,
|
||||||
|
model_type=ModelType.Lora,
|
||||||
|
submodel=None,
|
||||||
|
weight=self.weight,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.clip2 is not None:
|
||||||
|
output.clip2 = copy.deepcopy(self.clip2)
|
||||||
|
output.clip2.loras.append(
|
||||||
|
LoraInfo(
|
||||||
|
base_model=base_model,
|
||||||
|
model_name=lora_name,
|
||||||
|
model_type=ModelType.Lora,
|
||||||
|
submodel=None,
|
||||||
|
weight=self.weight,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class VAEModelField(BaseModel):
|
class VAEModelField(BaseModel):
|
||||||
"""Vae model field"""
|
"""Vae model field"""
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,6 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.dict(),
|
||||||
)
|
)
|
||||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
||||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
|
||||||
loras = [
|
loras = [
|
||||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||||
for lora in self.clip.loras
|
for lora in self.clip.loras
|
||||||
@@ -76,18 +75,14 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
# stack.enter_context(
|
(
|
||||||
# context.services.model_manager.get_model(
|
name,
|
||||||
# model_name=name,
|
context.services.model_manager.get_model(
|
||||||
# base_model=self.clip.text_encoder.base_model,
|
model_name=name,
|
||||||
# model_type=ModelType.TextualInversion,
|
base_model=self.clip.text_encoder.base_model,
|
||||||
# )
|
model_type=ModelType.TextualInversion,
|
||||||
# )
|
).context.model,
|
||||||
context.services.model_manager.get_model(
|
)
|
||||||
model_name=name,
|
|
||||||
base_model=self.clip.text_encoder.base_model,
|
|
||||||
model_type=ModelType.TextualInversion,
|
|
||||||
).context.model
|
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# print(e)
|
# print(e)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union
|
|||||||
|
|
||||||
from pydantic import Field, validator
|
from pydantic import Field, validator
|
||||||
|
|
||||||
from ...backend.model_management import ModelType, SubModelType
|
from ...backend.model_management import ModelType, SubModelType, ModelPatcher
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||||
|
|
||||||
@@ -293,10 +293,20 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
num_inference_steps = self.steps
|
num_inference_steps = self.steps
|
||||||
|
|
||||||
|
def _lora_loader():
|
||||||
|
for lora in self.unet.loras:
|
||||||
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.dict(exclude={"weight"}),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
yield (lora_info.context.model, lora.weight)
|
||||||
|
del lora_info
|
||||||
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
with unet_info as unet:
|
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
|
||||||
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
||||||
timesteps = scheduler.timesteps
|
timesteps = scheduler.timesteps
|
||||||
|
|
||||||
@@ -543,9 +553,19 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _lora_loader():
|
||||||
|
for lora in self.unet.loras:
|
||||||
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.dict(exclude={"weight"}),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
yield (lora_info.context.model, lora.weight)
|
||||||
|
del lora_info
|
||||||
|
return
|
||||||
|
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
with unet_info as unet:
|
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
|
||||||
# apply denoising_start
|
# apply denoising_start
|
||||||
num_inference_steps = self.steps
|
num_inference_steps = self.steps
|
||||||
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ class BoardImageRecordStorageBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def remove_image_from_board(
|
def remove_image_from_board(
|
||||||
self,
|
self,
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Removes an image from a board."""
|
"""Removes an image from a board."""
|
||||||
@@ -154,7 +153,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
|
|
||||||
def remove_image_from_board(
|
def remove_image_from_board(
|
||||||
self,
|
self,
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
@@ -162,9 +160,9 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
DELETE FROM board_images
|
DELETE FROM board_images
|
||||||
WHERE board_id = ? AND image_name = ?;
|
WHERE image_name = ?;
|
||||||
""",
|
""",
|
||||||
(board_id, image_name),
|
(image_name,),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ class BoardImagesServiceABC(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def remove_image_from_board(
|
def remove_image_from_board(
|
||||||
self,
|
self,
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Removes an image from a board."""
|
"""Removes an image from a board."""
|
||||||
@@ -93,10 +92,9 @@ class BoardImagesService(BoardImagesServiceABC):
|
|||||||
|
|
||||||
def remove_image_from_board(
|
def remove_image_from_board(
|
||||||
self,
|
self,
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._services.board_image_records.remove_image_from_board(board_id, image_name)
|
self._services.board_image_records.remove_image_from_board(image_name)
|
||||||
|
|
||||||
def get_all_board_image_names_for_board(
|
def get_all_board_image_names_for_board(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -274,7 +274,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _excluded(self) -> List[str]:
|
def _excluded(self) -> List[str]:
|
||||||
# internal fields that shouldn't be exposed as command line options
|
# internal fields that shouldn't be exposed as command line options
|
||||||
return ["type", "initconf", "cached_root"]
|
return ["type", "initconf"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _excluded_from_yaml(self) -> List[str]:
|
def _excluded_from_yaml(self) -> List[str]:
|
||||||
@@ -290,7 +290,6 @@ class InvokeAISettings(BaseSettings):
|
|||||||
"restore",
|
"restore",
|
||||||
"root",
|
"root",
|
||||||
"nsfw_checker",
|
"nsfw_checker",
|
||||||
"cached_root",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@@ -356,7 +355,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
def _find_root() -> Path:
|
def _find_root() -> Path:
|
||||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||||
if os.environ.get("INVOKEAI_ROOT"):
|
if os.environ.get("INVOKEAI_ROOT"):
|
||||||
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
root = Path(os.environ["INVOKEAI_ROOT"])
|
||||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
||||||
root = (venv.parent).resolve()
|
root = (venv.parent).resolve()
|
||||||
else:
|
else:
|
||||||
@@ -403,7 +402,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||||
|
|
||||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
|
||||||
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
||||||
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
||||||
@@ -415,6 +414,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||||
|
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert')
|
||||||
|
|
||||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||||
|
|
||||||
@@ -424,7 +424,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||||
|
|
||||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||||
cached_root : Path = Field(default=None, description="internal use only", category="DEPRECATED")
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
|
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
|
||||||
@@ -472,15 +471,12 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
"""
|
"""
|
||||||
Path to the runtime root directory
|
Path to the runtime root directory
|
||||||
"""
|
"""
|
||||||
# we cache value of root to protect against it being '.' and the cwd changing
|
if self.root:
|
||||||
if self.cached_root:
|
|
||||||
root = self.cached_root
|
|
||||||
elif self.root:
|
|
||||||
root = Path(self.root).expanduser().absolute()
|
root = Path(self.root).expanduser().absolute()
|
||||||
else:
|
else:
|
||||||
root = self.find_root()
|
root = self.find_root().expanduser().absolute()
|
||||||
self.cached_root = root
|
self.root = root # insulate ourselves from relative paths that may change
|
||||||
return self.cached_root
|
return root
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def root_dir(self) -> Path:
|
def root_dir(self) -> Path:
|
||||||
|
|||||||
@@ -289,9 +289,10 @@ class ImageService(ImageServiceABC):
|
|||||||
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
||||||
try:
|
try:
|
||||||
image_record = self._services.image_records.get(image_name)
|
image_record = self._services.image_records.get(image_name)
|
||||||
|
metadata = self._services.image_records.get_metadata(image_name)
|
||||||
|
|
||||||
if not image_record.session_id:
|
if not image_record.session_id:
|
||||||
return ImageMetadata()
|
return ImageMetadata(metadata=metadata)
|
||||||
|
|
||||||
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
||||||
graph = None
|
graph = None
|
||||||
@@ -303,7 +304,6 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||||
graph = None
|
graph = None
|
||||||
|
|
||||||
metadata = self._services.image_records.get_metadata(image_name)
|
|
||||||
return ImageMetadata(graph=graph, metadata=metadata)
|
return ImageMetadata(graph=graph, metadata=metadata)
|
||||||
except ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Image record not found")
|
self._services.logger.error("Image record not found")
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ class InvocationServices:
|
|||||||
logger: "Logger"
|
logger: "Logger"
|
||||||
model_manager: "ModelManagerServiceBase"
|
model_manager: "ModelManagerServiceBase"
|
||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
|
performance_statistics: "InvocationStatsServiceBase"
|
||||||
queue: "InvocationQueueABC"
|
queue: "InvocationQueueABC"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -47,6 +48,7 @@ class InvocationServices:
|
|||||||
logger: "Logger",
|
logger: "Logger",
|
||||||
model_manager: "ModelManagerServiceBase",
|
model_manager: "ModelManagerServiceBase",
|
||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
|
performance_statistics: "InvocationStatsServiceBase",
|
||||||
queue: "InvocationQueueABC",
|
queue: "InvocationQueueABC",
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
@@ -61,4 +63,5 @@ class InvocationServices:
|
|||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
|
self.performance_statistics = performance_statistics
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
|
|||||||
223
invokeai/app/services/invocation_stats.py
Normal file
223
invokeai/app/services/invocation_stats.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
||||||
|
"""Utility to collect execution time and GPU usage stats on invocations in flight"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
statistics = InvocationStatsService(graph_execution_manager)
|
||||||
|
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||||
|
... execute graphs...
|
||||||
|
statistics.log_stats()
|
||||||
|
|
||||||
|
Typical output:
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
||||||
|
|
||||||
|
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
||||||
|
writes to the system log is stored in InvocationServices.performance_statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from contextlib import AbstractContextManager
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
from ..invocations.baseinvocation import BaseInvocation
|
||||||
|
from .graph import GraphExecutionState
|
||||||
|
from .item_storage import ItemStorageABC
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationStatsServiceBase(ABC):
|
||||||
|
"Abstract base class for recording node memory/time performance statistics"
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||||
|
"""
|
||||||
|
Initialize the InvocationStatsService and reset counters to zero
|
||||||
|
:param graph_execution_manager: Graph execution manager for this session
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def collect_stats(
|
||||||
|
self,
|
||||||
|
invocation: BaseInvocation,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
) -> AbstractContextManager:
|
||||||
|
"""
|
||||||
|
Return a context object that will capture the statistics on the execution
|
||||||
|
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
||||||
|
:param invocation: BaseInvocation object from the current graph.
|
||||||
|
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset_stats(self, graph_execution_state_id: str):
|
||||||
|
"""
|
||||||
|
Reset all statistics for the indicated graph
|
||||||
|
:param graph_execution_state_id
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset_all_stats(self):
|
||||||
|
"""Zero all statistics"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_invocation_stats(
|
||||||
|
self,
|
||||||
|
graph_id: str,
|
||||||
|
invocation_type: str,
|
||||||
|
time_used: float,
|
||||||
|
vram_used: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add timing information on execution of a node. Usually
|
||||||
|
used internally.
|
||||||
|
:param graph_id: ID of the graph that is currently executing
|
||||||
|
:param invocation_type: String literal type of the node
|
||||||
|
:param time_used: Time used by node's exection (sec)
|
||||||
|
:param vram_used: Maximum VRAM used during exection (GB)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def log_stats(self):
|
||||||
|
"""
|
||||||
|
Write out the accumulated statistics to the log or somewhere else.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeStats:
|
||||||
|
"""Class for tracking execution stats of an invocation node"""
|
||||||
|
|
||||||
|
calls: int = 0
|
||||||
|
time_used: float = 0.0 # seconds
|
||||||
|
max_vram: float = 0.0 # GB
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeLog:
|
||||||
|
"""Class for tracking node usage"""
|
||||||
|
|
||||||
|
# {node_type => NodeStats}
|
||||||
|
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationStatsService(InvocationStatsServiceBase):
|
||||||
|
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
||||||
|
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
||||||
|
|
||||||
|
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||||
|
self.graph_execution_manager = graph_execution_manager
|
||||||
|
# {graph_id => NodeLog}
|
||||||
|
self._stats: Dict[str, NodeLog] = {}
|
||||||
|
|
||||||
|
class StatsContext:
|
||||||
|
def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
|
||||||
|
self.invocation = invocation
|
||||||
|
self.collector = collector
|
||||||
|
self.graph_id = graph_id
|
||||||
|
self.start_time = 0
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start_time = time.time()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
self.collector.update_invocation_stats(
|
||||||
|
self.graph_id,
|
||||||
|
self.invocation.type,
|
||||||
|
time.time() - self.start_time,
|
||||||
|
torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def collect_stats(
|
||||||
|
self,
|
||||||
|
invocation: BaseInvocation,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
) -> StatsContext:
|
||||||
|
"""
|
||||||
|
Return a context object that will capture the statistics.
|
||||||
|
:param invocation: BaseInvocation object from the current graph.
|
||||||
|
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||||
|
"""
|
||||||
|
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||||
|
self._stats[graph_execution_state_id] = NodeLog()
|
||||||
|
return self.StatsContext(invocation, graph_execution_state_id, self)
|
||||||
|
|
||||||
|
def reset_all_stats(self):
|
||||||
|
"""Zero all statistics"""
|
||||||
|
self._stats = {}
|
||||||
|
|
||||||
|
def reset_stats(self, graph_execution_id: str):
|
||||||
|
"""Zero the statistics for the indicated graph."""
|
||||||
|
try:
|
||||||
|
self._stats.pop(graph_execution_id)
|
||||||
|
except KeyError:
|
||||||
|
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
|
||||||
|
|
||||||
|
def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
|
||||||
|
"""
|
||||||
|
Add timing information on execution of a node. Usually
|
||||||
|
used internally.
|
||||||
|
:param graph_id: ID of the graph that is currently executing
|
||||||
|
:param invocation_type: String literal type of the node
|
||||||
|
:param time_used: Floating point seconds used by node's exection
|
||||||
|
"""
|
||||||
|
if not self._stats[graph_id].nodes.get(invocation_type):
|
||||||
|
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
||||||
|
stats = self._stats[graph_id].nodes[invocation_type]
|
||||||
|
stats.calls += 1
|
||||||
|
stats.time_used += time_used
|
||||||
|
stats.max_vram = max(stats.max_vram, vram_used)
|
||||||
|
|
||||||
|
def log_stats(self):
|
||||||
|
"""
|
||||||
|
Send the statistics to the system logger at the info level.
|
||||||
|
Stats will only be printed if when the execution of the graph
|
||||||
|
is complete.
|
||||||
|
"""
|
||||||
|
completed = set()
|
||||||
|
for graph_id, node_log in self._stats.items():
|
||||||
|
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||||
|
if not current_graph_state.is_complete():
|
||||||
|
continue
|
||||||
|
|
||||||
|
total_time = 0
|
||||||
|
logger.info(f"Graph stats: {graph_id}")
|
||||||
|
logger.info("Node Calls Seconds VRAM Used")
|
||||||
|
for node_type, stats in self._stats[graph_id].nodes.items():
|
||||||
|
logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
|
||||||
|
total_time += stats.time_used
|
||||||
|
|
||||||
|
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
|
||||||
|
|
||||||
|
completed.add(graph_id)
|
||||||
|
|
||||||
|
for graph_id in completed:
|
||||||
|
del self._stats[graph_id]
|
||||||
@@ -3,9 +3,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
from typing import Literal, Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
|
||||||
from invokeai.backend.model_management import (
|
from invokeai.backend.model_management import (
|
||||||
@@ -193,7 +194,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
@@ -292,7 +293,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: InvokeAIAppConfig,
|
config: InvokeAIAppConfig,
|
||||||
logger: ModuleType,
|
logger: Logger,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file.
|
Initialize with the path to the models.yaml config file.
|
||||||
@@ -396,7 +397,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
model_type,
|
model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||||
"""
|
"""
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||||
"""
|
"""
|
||||||
@@ -416,7 +417,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
"""
|
"""
|
||||||
return self.mgr.list_models(base_model, model_type)
|
return self.mgr.list_models(base_model, model_type)
|
||||||
|
|
||||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||||
"""
|
"""
|
||||||
Return information about the model using the same format as list_models()
|
Return information about the model using the same format as list_models()
|
||||||
"""
|
"""
|
||||||
@@ -429,7 +430,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_attributes: dict,
|
model_attributes: dict,
|
||||||
clobber: bool = False,
|
clobber: bool = False,
|
||||||
) -> None:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
@@ -478,7 +479,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||||
convert_dest_directory: Optional[Path] = Field(
|
convert_dest_directory: Optional[Path] = Field(
|
||||||
default=None, description="Optional directory location for merged model"
|
default=None, description="Optional directory location for merged model"
|
||||||
),
|
),
|
||||||
@@ -573,9 +574,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
default=None, description="Base model shared by all models to be merged"
|
default=None, description="Base model shared by all models to be merged"
|
||||||
),
|
),
|
||||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||||
alpha: Optional[float] = 0.5,
|
alpha: float = 0.5,
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
force: Optional[bool] = False,
|
force: bool = False,
|
||||||
merge_dest_directory: Optional[Path] = Field(
|
merge_dest_directory: Optional[Path] = Field(
|
||||||
default=None, description="Optional directory location for merged model"
|
default=None, description="Optional directory location for merged model"
|
||||||
),
|
),
|
||||||
@@ -633,8 +634,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
new_name: str = None,
|
new_name: Optional[str] = None,
|
||||||
new_base: BaseModelType = None,
|
new_base: Optional[BaseModelType] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Rename the indicated model. Can provide a new name and/or a new base.
|
Rename the indicated model. Can provide a new name and/or a new base.
|
||||||
|
|||||||
8
invokeai/app/services/models/board_image.py
Normal file
8
invokeai/app/services/models/board_image.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImage(BaseModelExcludeNull):
|
||||||
|
board_id: str = Field(description="The id of the board")
|
||||||
|
image_name: str = Field(description="The name of the image")
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
from pydantic import Field
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
|
||||||
class BoardRecord(BaseModel):
|
class BoardRecord(BaseModelExcludeNull):
|
||||||
"""Deserialized board record."""
|
"""Deserialized board record."""
|
||||||
|
|
||||||
board_id: str = Field(description="The unique ID of the board.")
|
board_id: str = Field(description="The unique ID of the board.")
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
import datetime
|
import datetime
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
from pydantic import Extra, Field, StrictBool, StrictStr
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
|
||||||
class ImageRecord(BaseModel):
|
class ImageRecord(BaseModelExcludeNull):
|
||||||
"""Deserialized image record without metadata."""
|
"""Deserialized image record without metadata."""
|
||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
@@ -40,7 +41,7 @@ class ImageRecord(BaseModel):
|
|||||||
"""The node ID that generated this image, if it is a generated image."""
|
"""The node ID that generated this image, if it is a generated image."""
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
||||||
"""A set of changes to apply to an image record.
|
"""A set of changes to apply to an image record.
|
||||||
|
|
||||||
Only limited changes are valid:
|
Only limited changes are valid:
|
||||||
@@ -60,7 +61,7 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
|||||||
"""The image's new `is_intermediate` flag."""
|
"""The image's new `is_intermediate` flag."""
|
||||||
|
|
||||||
|
|
||||||
class ImageUrlsDTO(BaseModel):
|
class ImageUrlsDTO(BaseModelExcludeNull):
|
||||||
"""The URLs for an image and its thumbnail."""
|
"""The URLs for an image and its thumbnail."""
|
||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
@@ -76,11 +77,15 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
|
|||||||
|
|
||||||
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
||||||
"""The id of the board the image belongs to, if one exists."""
|
"""The id of the board the image belongs to, if one exists."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def image_record_to_dto(
|
def image_record_to_dto(
|
||||||
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str]
|
image_record: ImageRecord,
|
||||||
|
image_url: str,
|
||||||
|
thumbnail_url: str,
|
||||||
|
board_id: Optional[str],
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Converts an image record to an image DTO."""
|
"""Converts an image record to an image DTO."""
|
||||||
return ImageDTO(
|
return ImageDTO(
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from threading import Event, Thread, BoundedSemaphore
|
from threading import BoundedSemaphore, Event, Thread
|
||||||
|
|
||||||
from ..invocations.baseinvocation import InvocationContext
|
|
||||||
from .invocation_queue import InvocationQueueItem
|
|
||||||
from .invoker import InvocationProcessorABC, Invoker
|
|
||||||
from ..models.exceptions import CanceledException
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
|
from ..models.exceptions import CanceledException
|
||||||
|
from .invocation_queue import InvocationQueueItem
|
||||||
|
from .invocation_stats import InvocationStatsServiceBase
|
||||||
|
from .invoker import InvocationProcessorABC, Invoker
|
||||||
|
|
||||||
|
|
||||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||||
__invoker_thread: Thread
|
__invoker_thread: Thread
|
||||||
@@ -35,6 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
def __process(self, stop_event: Event):
|
def __process(self, stop_event: Event):
|
||||||
try:
|
try:
|
||||||
self.__threadLimit.acquire()
|
self.__threadLimit.acquire()
|
||||||
|
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
||||||
|
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||||
@@ -83,35 +86,38 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
|
|
||||||
# Invoke
|
# Invoke
|
||||||
try:
|
try:
|
||||||
outputs = invocation.invoke(
|
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||||
InvocationContext(
|
outputs = invocation.invoke(
|
||||||
services=self.__invoker.services,
|
InvocationContext(
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
services=self.__invoker.services,
|
||||||
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
||||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Save outputs and history
|
# Save outputs and history
|
||||||
graph_execution_state.complete(invocation.id, outputs)
|
graph_execution_state.complete(invocation.id, outputs)
|
||||||
|
|
||||||
# Save the state changes
|
# Save the state changes
|
||||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||||
|
|
||||||
# Send complete event
|
# Send complete event
|
||||||
self.__invoker.services.events.emit_invocation_complete(
|
self.__invoker.services.events.emit_invocation_complete(
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.dict(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
result=outputs.dict(),
|
result=outputs.dict(),
|
||||||
)
|
)
|
||||||
|
statistics.log_stats()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except CanceledException:
|
except CanceledException:
|
||||||
|
statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -133,7 +139,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=error,
|
error=error,
|
||||||
)
|
)
|
||||||
|
statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
||||||
|
|||||||
@@ -20,6 +20,6 @@ class LocalUrlService(UrlServiceBase):
|
|||||||
|
|
||||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||||
if thumbnail:
|
if thumbnail:
|
||||||
return f"{self._base_url}/images/{image_basename}/thumbnail"
|
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
||||||
|
|
||||||
return f"{self._base_url}/images/{image_basename}/full"
|
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||||
|
|||||||
@@ -18,5 +18,5 @@ SEED_MAX = np.iinfo(np.uint32).max
|
|||||||
|
|
||||||
|
|
||||||
def get_random_seed():
|
def get_random_seed():
|
||||||
rng = np.random.default_rng(seed=0)
|
rng = np.random.default_rng(seed=None)
|
||||||
return int(rng.integers(0, SEED_MAX))
|
return int(rng.integers(0, SEED_MAX))
|
||||||
|
|||||||
23
invokeai/app/util/model_exclude_null.py
Normal file
23
invokeai/app/util/model_exclude_null.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from typing import Any
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
We want to exclude null values from objects that make their way to the client.
|
||||||
|
|
||||||
|
Unfortunately there is no built-in way to do this in pydantic, so we need to override the default
|
||||||
|
dict method to do this.
|
||||||
|
|
||||||
|
From https://github.com/tiangolo/fastapi/discussions/8882#discussioncomment-5154541
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModelExcludeNull(BaseModel):
|
||||||
|
def dict(self, *args, **kwargs) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Override the default dict method to exclude None values in the response
|
||||||
|
"""
|
||||||
|
kwargs.pop("exclude_none", None)
|
||||||
|
return super().dict(*args, exclude_none=True, **kwargs)
|
||||||
|
|
||||||
|
pass
|
||||||
@@ -12,16 +12,17 @@ def check_invokeai_root(config: InvokeAIAppConfig):
|
|||||||
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
|
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
|
||||||
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
|
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
|
||||||
assert config.models_path.exists(), f"{config.models_path} not found"
|
assert config.models_path.exists(), f"{config.models_path} not found"
|
||||||
for model in [
|
if not config.ignore_missing_core_models:
|
||||||
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
|
for model in [
|
||||||
"bert-base-uncased",
|
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
|
||||||
"clip-vit-large-patch14",
|
"bert-base-uncased",
|
||||||
"sd-vae-ft-mse",
|
"clip-vit-large-patch14",
|
||||||
"stable-diffusion-2-clip",
|
"sd-vae-ft-mse",
|
||||||
"stable-diffusion-safety-checker",
|
"stable-diffusion-2-clip",
|
||||||
]:
|
"stable-diffusion-safety-checker",
|
||||||
path = config.models_path / f"core/convert/{model}"
|
]:
|
||||||
assert path.exists(), f"{path} is missing"
|
path = config.models_path / f"core/convert/{model}"
|
||||||
|
assert path.exists(), f"{path} is missing"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print()
|
print()
|
||||||
print(f"An exception has occurred: {str(e)}")
|
print(f"An exception has occurred: {str(e)}")
|
||||||
@@ -32,5 +33,10 @@ def check_invokeai_root(config: InvokeAIAppConfig):
|
|||||||
print(
|
print(
|
||||||
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
|
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
|
||||||
)
|
)
|
||||||
|
print(
|
||||||
|
'** (To skip this check completely, add "--ignore_missing_core_models" to your CLI args. Not installing '
|
||||||
|
"these core models will prevent the loading of some or all .safetensors and .ckpt files. However, you can "
|
||||||
|
"always come back and install these core models in the future.)"
|
||||||
|
)
|
||||||
input("Press any key to continue...")
|
input("Press any key to continue...")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import requests
|
|||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers import logging as dlogging
|
from diffusers import logging as dlogging
|
||||||
import onnx
|
import onnx
|
||||||
|
import torch
|
||||||
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -23,6 +24,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||||
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
||||||
from invokeai.backend.util import download_with_resume
|
from invokeai.backend.util import download_with_resume
|
||||||
|
from invokeai.backend.util.devices import torch_dtype, choose_torch_device
|
||||||
from ..util.logging import InvokeAILogger
|
from ..util.logging import InvokeAILogger
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
@@ -99,9 +101,9 @@ class ModelInstall(object):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: InvokeAIAppConfig,
|
config: InvokeAIAppConfig,
|
||||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||||
model_manager: ModelManager = None,
|
model_manager: Optional[ModelManager] = None,
|
||||||
access_token: str = None,
|
access_token: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||||
@@ -303,7 +305,7 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
staging = Path(staging)
|
staging = Path(staging)
|
||||||
if "model_index.json" in files and "unet/model.onnx" not in files:
|
if "model_index.json" in files:
|
||||||
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
||||||
elif "unet/model.onnx" in files:
|
elif "unet/model.onnx" in files:
|
||||||
location = self._download_hf_model(repo_id, files, staging)
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
@@ -416,15 +418,25 @@ class ModelInstall(object):
|
|||||||
does a save_pretrained() to the indicated staging area.
|
does a save_pretrained() to the indicated staging area.
|
||||||
"""
|
"""
|
||||||
_, name = repo_id.split("/")
|
_, name = repo_id.split("/")
|
||||||
revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"]
|
precision = torch_dtype(choose_torch_device())
|
||||||
|
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
|
||||||
|
|
||||||
model = None
|
model = None
|
||||||
for revision in revisions:
|
for variant in variants:
|
||||||
try:
|
try:
|
||||||
model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
|
model = DiffusionPipeline.from_pretrained(
|
||||||
except: # most errors are due to fp16 not being present. Fix this to catch other errors
|
repo_id,
|
||||||
pass
|
variant=variant,
|
||||||
|
torch_dtype=precision,
|
||||||
|
safety_checker=None,
|
||||||
|
)
|
||||||
|
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
|
||||||
|
if "fp16" not in str(e):
|
||||||
|
print(e)
|
||||||
|
|
||||||
if model:
|
if model:
|
||||||
break
|
break
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
|
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -13,3 +13,4 @@ from .models import (
|
|||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
)
|
)
|
||||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||||
|
from .lora import ModelPatcher
|
||||||
|
|||||||
@@ -20,424 +20,6 @@ from diffusers.models import UNet2DConditionModel
|
|||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
# TODO: rename and split this file
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALayerBase:
|
|
||||||
# rank: Optional[int]
|
|
||||||
# alpha: Optional[float]
|
|
||||||
# bias: Optional[torch.Tensor]
|
|
||||||
# layer_key: str
|
|
||||||
|
|
||||||
# @property
|
|
||||||
# def scale(self):
|
|
||||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
if "alpha" in values:
|
|
||||||
self.alpha = values["alpha"].item()
|
|
||||||
else:
|
|
||||||
self.alpha = None
|
|
||||||
|
|
||||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
|
||||||
self.bias = torch.sparse_coo_tensor(
|
|
||||||
values["bias_indices"],
|
|
||||||
values["bias_values"],
|
|
||||||
tuple(values["bias_size"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
self.rank = None # set in layer implementation
|
|
||||||
self.layer_key = layer_key
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
module: torch.nn.Module,
|
|
||||||
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
|
||||||
multiplier: float,
|
|
||||||
):
|
|
||||||
if type(module) == torch.nn.Conv2d:
|
|
||||||
op = torch.nn.functional.conv2d
|
|
||||||
extra_args = dict(
|
|
||||||
stride=module.stride,
|
|
||||||
padding=module.padding,
|
|
||||||
dilation=module.dilation,
|
|
||||||
groups=module.groups,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
op = torch.nn.functional.linear
|
|
||||||
extra_args = {}
|
|
||||||
|
|
||||||
weight = self.get_weight()
|
|
||||||
|
|
||||||
bias = self.bias if self.bias is not None else 0
|
|
||||||
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
|
||||||
return (
|
|
||||||
op(
|
|
||||||
*input_h,
|
|
||||||
(weight + bias).view(module.weight.shape),
|
|
||||||
None,
|
|
||||||
**extra_args,
|
|
||||||
)
|
|
||||||
* multiplier
|
|
||||||
* scale
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = 0
|
|
||||||
for val in [self.bias]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
if self.bias is not None:
|
|
||||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: find and debug lora/locon with bias
|
|
||||||
class LoRALayer(LoRALayerBase):
|
|
||||||
# up: torch.Tensor
|
|
||||||
# mid: Optional[torch.Tensor]
|
|
||||||
# down: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.up = values["lora_up.weight"]
|
|
||||||
self.down = values["lora_down.weight"]
|
|
||||||
if "lora_mid.weight" in values:
|
|
||||||
self.mid = values["lora_mid.weight"]
|
|
||||||
else:
|
|
||||||
self.mid = None
|
|
||||||
|
|
||||||
self.rank = self.down.shape[0]
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
if self.mid is not None:
|
|
||||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
|
||||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
|
||||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
|
||||||
else:
|
|
||||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.up, self.mid, self.down]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.up = self.up.to(device=device, dtype=dtype)
|
|
||||||
self.down = self.down.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.mid is not None:
|
|
||||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoHALayer(LoRALayerBase):
|
|
||||||
# w1_a: torch.Tensor
|
|
||||||
# w1_b: torch.Tensor
|
|
||||||
# w2_a: torch.Tensor
|
|
||||||
# w2_b: torch.Tensor
|
|
||||||
# t1: Optional[torch.Tensor] = None
|
|
||||||
# t2: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.w1_a = values["hada_w1_a"]
|
|
||||||
self.w1_b = values["hada_w1_b"]
|
|
||||||
self.w2_a = values["hada_w2_a"]
|
|
||||||
self.w2_b = values["hada_w2_b"]
|
|
||||||
|
|
||||||
if "hada_t1" in values:
|
|
||||||
self.t1 = values["hada_t1"]
|
|
||||||
else:
|
|
||||||
self.t1 = None
|
|
||||||
|
|
||||||
if "hada_t2" in values:
|
|
||||||
self.t2 = values["hada_t2"]
|
|
||||||
else:
|
|
||||||
self.t2 = None
|
|
||||||
|
|
||||||
self.rank = self.w1_b.shape[0]
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
if self.t1 is None:
|
|
||||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
|
||||||
|
|
||||||
else:
|
|
||||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
|
||||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
|
||||||
weight = rebuild1 * rebuild2
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
|
||||||
if self.t1 is not None:
|
|
||||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
|
||||||
if self.t2 is not None:
|
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoKRLayer(LoRALayerBase):
|
|
||||||
# w1: Optional[torch.Tensor] = None
|
|
||||||
# w1_a: Optional[torch.Tensor] = None
|
|
||||||
# w1_b: Optional[torch.Tensor] = None
|
|
||||||
# w2: Optional[torch.Tensor] = None
|
|
||||||
# w2_a: Optional[torch.Tensor] = None
|
|
||||||
# w2_b: Optional[torch.Tensor] = None
|
|
||||||
# t2: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
if "lokr_w1" in values:
|
|
||||||
self.w1 = values["lokr_w1"]
|
|
||||||
self.w1_a = None
|
|
||||||
self.w1_b = None
|
|
||||||
else:
|
|
||||||
self.w1 = None
|
|
||||||
self.w1_a = values["lokr_w1_a"]
|
|
||||||
self.w1_b = values["lokr_w1_b"]
|
|
||||||
|
|
||||||
if "lokr_w2" in values:
|
|
||||||
self.w2 = values["lokr_w2"]
|
|
||||||
self.w2_a = None
|
|
||||||
self.w2_b = None
|
|
||||||
else:
|
|
||||||
self.w2 = None
|
|
||||||
self.w2_a = values["lokr_w2_a"]
|
|
||||||
self.w2_b = values["lokr_w2_b"]
|
|
||||||
|
|
||||||
if "lokr_t2" in values:
|
|
||||||
self.t2 = values["lokr_t2"]
|
|
||||||
else:
|
|
||||||
self.t2 = None
|
|
||||||
|
|
||||||
if "lokr_w1_b" in values:
|
|
||||||
self.rank = values["lokr_w1_b"].shape[0]
|
|
||||||
elif "lokr_w2_b" in values:
|
|
||||||
self.rank = values["lokr_w2_b"].shape[0]
|
|
||||||
else:
|
|
||||||
self.rank = None # unscaled
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
w1 = self.w1
|
|
||||||
if w1 is None:
|
|
||||||
w1 = self.w1_a @ self.w1_b
|
|
||||||
|
|
||||||
w2 = self.w2
|
|
||||||
if w2 is None:
|
|
||||||
if self.t2 is None:
|
|
||||||
w2 = self.w2_a @ self.w2_b
|
|
||||||
else:
|
|
||||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
||||||
w2 = w2.contiguous()
|
|
||||||
weight = torch.kron(w1, w2)
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.w1 is not None:
|
|
||||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.w2 is not None:
|
|
||||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.t2 is not None:
|
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModel: # (torch.nn.Module):
|
|
||||||
_name: str
|
|
||||||
layers: Dict[str, LoRALayer]
|
|
||||||
_device: torch.device
|
|
||||||
_dtype: torch.dtype
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
layers: Dict[str, LoRALayer],
|
|
||||||
device: torch.device,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
):
|
|
||||||
self._name = name
|
|
||||||
self._device = device or torch.cpu
|
|
||||||
self._dtype = dtype or torch.float32
|
|
||||||
self.layers = layers
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self):
|
|
||||||
return self._device
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self):
|
|
||||||
return self._dtype
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
) -> LoRAModel:
|
|
||||||
# TODO: try revert if exception?
|
|
||||||
for key, layer in self.layers.items():
|
|
||||||
layer.to(device=device, dtype=dtype)
|
|
||||||
self._device = device
|
|
||||||
self._dtype = dtype
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = 0
|
|
||||||
for _, layer in self.layers.items():
|
|
||||||
model_size += layer.calc_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_checkpoint(
|
|
||||||
cls,
|
|
||||||
file_path: Union[str, Path],
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
device = device or torch.device("cpu")
|
|
||||||
dtype = dtype or torch.float32
|
|
||||||
|
|
||||||
if isinstance(file_path, str):
|
|
||||||
file_path = Path(file_path)
|
|
||||||
|
|
||||||
model = cls(
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
name=file_path.stem, # TODO:
|
|
||||||
layers=dict(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
|
||||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
|
||||||
else:
|
|
||||||
state_dict = torch.load(file_path, map_location="cpu")
|
|
||||||
|
|
||||||
state_dict = cls._group_state(state_dict)
|
|
||||||
|
|
||||||
for layer_key, values in state_dict.items():
|
|
||||||
# lora and locon
|
|
||||||
if "lora_down.weight" in values:
|
|
||||||
layer = LoRALayer(layer_key, values)
|
|
||||||
|
|
||||||
# loha
|
|
||||||
elif "hada_w1_b" in values:
|
|
||||||
layer = LoHALayer(layer_key, values)
|
|
||||||
|
|
||||||
# lokr
|
|
||||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
|
||||||
layer = LoKRLayer(layer_key, values)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# TODO: diff/ia3/... format
|
|
||||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# lower memory consumption by removing already parsed layer values
|
|
||||||
state_dict[layer_key].clear()
|
|
||||||
|
|
||||||
layer.to(device=device, dtype=dtype)
|
|
||||||
model.layers[layer_key] = layer
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _group_state(state_dict: dict):
|
|
||||||
state_dict_groupped = dict()
|
|
||||||
|
|
||||||
for key, value in state_dict.items():
|
|
||||||
stem, leaf = key.split(".", 1)
|
|
||||||
if stem not in state_dict_groupped:
|
|
||||||
state_dict_groupped[stem] = dict()
|
|
||||||
state_dict_groupped[stem][leaf] = value
|
|
||||||
|
|
||||||
return state_dict_groupped
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
loras = [
|
loras = [
|
||||||
(lora_model1, 0.7),
|
(lora_model1, 0.7),
|
||||||
@@ -516,6 +98,26 @@ class ModelPatcher:
|
|||||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def apply_sdxl_lora_text_encoder(
|
||||||
|
cls,
|
||||||
|
text_encoder: CLIPTextModel,
|
||||||
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
|
):
|
||||||
|
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
|
||||||
|
yield
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def apply_sdxl_lora_text_encoder2(
|
||||||
|
cls,
|
||||||
|
text_encoder: CLIPTextModel,
|
||||||
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
|
):
|
||||||
|
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
|
||||||
|
yield
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_lora(
|
def apply_lora(
|
||||||
@@ -562,7 +164,7 @@ class ModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
ti_list: List[Any],
|
ti_list: List[Tuple[str, Any]],
|
||||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
||||||
init_tokens_count = None
|
init_tokens_count = None
|
||||||
new_tokens_added = None
|
new_tokens_added = None
|
||||||
@@ -572,27 +174,27 @@ class ModelPatcher:
|
|||||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||||
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
||||||
|
|
||||||
def _get_trigger(ti, index):
|
def _get_trigger(ti_name, index):
|
||||||
trigger = ti.name
|
trigger = ti_name
|
||||||
if index > 0:
|
if index > 0:
|
||||||
trigger += f"-!pad-{i}"
|
trigger += f"-!pad-{i}"
|
||||||
return f"<{trigger}>"
|
return f"<{trigger}>"
|
||||||
|
|
||||||
# modify tokenizer
|
# modify tokenizer
|
||||||
new_tokens_added = 0
|
new_tokens_added = 0
|
||||||
for ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
|
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
||||||
|
|
||||||
# modify text_encoder
|
# modify text_encoder
|
||||||
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
|
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
|
||||||
model_embeddings = text_encoder.get_input_embeddings()
|
model_embeddings = text_encoder.get_input_embeddings()
|
||||||
|
|
||||||
for ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
ti_tokens = []
|
ti_tokens = []
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
embedding = ti.embedding[i]
|
embedding = ti.embedding[i]
|
||||||
trigger = _get_trigger(ti, i)
|
trigger = _get_trigger(ti_name, i)
|
||||||
|
|
||||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||||
if token_id == ti_tokenizer.unk_token_id:
|
if token_id == ti_tokenizer.unk_token_id:
|
||||||
@@ -637,7 +239,6 @@ class ModelPatcher:
|
|||||||
|
|
||||||
|
|
||||||
class TextualInversionModel:
|
class TextualInversionModel:
|
||||||
name: str
|
|
||||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -651,7 +252,6 @@ class TextualInversionModel:
|
|||||||
file_path = Path(file_path)
|
file_path = Path(file_path)
|
||||||
|
|
||||||
result = cls() # TODO:
|
result = cls() # TODO:
|
||||||
result.name = file_path.stem # TODO:
|
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
if file_path.suffix == ".safetensors":
|
||||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||||
@@ -828,7 +428,7 @@ class ONNXModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
text_encoder: IAIOnnxRuntimeModel,
|
text_encoder: IAIOnnxRuntimeModel,
|
||||||
ti_list: List[Any],
|
ti_list: List[Tuple[str, Any]],
|
||||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
||||||
from .models.base import IAIOnnxRuntimeModel
|
from .models.base import IAIOnnxRuntimeModel
|
||||||
|
|
||||||
@@ -841,17 +441,17 @@ class ONNXModelPatcher:
|
|||||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||||
|
|
||||||
def _get_trigger(ti, index):
|
def _get_trigger(ti_name, index):
|
||||||
trigger = ti.name
|
trigger = ti_name
|
||||||
if index > 0:
|
if index > 0:
|
||||||
trigger += f"-!pad-{i}"
|
trigger += f"-!pad-{i}"
|
||||||
return f"<{trigger}>"
|
return f"<{trigger}>"
|
||||||
|
|
||||||
# modify tokenizer
|
# modify tokenizer
|
||||||
new_tokens_added = 0
|
new_tokens_added = 0
|
||||||
for ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
|
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
||||||
|
|
||||||
# modify text_encoder
|
# modify text_encoder
|
||||||
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
||||||
@@ -861,11 +461,11 @@ class ONNXModelPatcher:
|
|||||||
axis=0,
|
axis=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
for ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
ti_tokens = []
|
ti_tokens = []
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
embedding = ti.embedding[i].detach().numpy()
|
embedding = ti.embedding[i].detach().numpy()
|
||||||
trigger = _get_trigger(ti, i)
|
trigger = _get_trigger(ti_name, i)
|
||||||
|
|
||||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||||
if token_id == ti_tokenizer.unk_token_id:
|
if token_id == ti_tokenizer.unk_token_id:
|
||||||
|
|||||||
@@ -28,8 +28,6 @@ import torch
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
|
||||||
from .lora import LoRAModel, TextualInversionModel
|
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
||||||
|
|
||||||
# Maximum size of the cache, in gigs
|
# Maximum size of the cache, in gigs
|
||||||
@@ -188,7 +186,7 @@ class ModelCache(object):
|
|||||||
cache_entry = self._cached_models.get(key, None)
|
cache_entry = self._cached_models.get(key, None)
|
||||||
if cache_entry is None:
|
if cache_entry is None:
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}"
|
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# this will remove older cached models until
|
# this will remove older cached models until
|
||||||
|
|||||||
@@ -234,7 +234,7 @@ import textwrap
|
|||||||
import yaml
|
import yaml
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||||
from shutil import rmtree, move
|
from shutil import rmtree, move
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -472,7 +472,7 @@ class ModelManager(object):
|
|||||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||||
override_path = getattr(model_config, submodel_type)
|
override_path = getattr(model_config, submodel_type)
|
||||||
if override_path:
|
if override_path:
|
||||||
model_path = self.app_config.root_path / override_path
|
model_path = self.resolve_path(override_path)
|
||||||
model_type = submodel_type
|
model_type = submodel_type
|
||||||
submodel_type = None
|
submodel_type = None
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
@@ -518,7 +518,7 @@ class ModelManager(object):
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
) -> dict:
|
) -> Union[dict, None]:
|
||||||
"""
|
"""
|
||||||
Given a model name returns the OmegaConf (dict-like) object describing it.
|
Given a model name returns the OmegaConf (dict-like) object describing it.
|
||||||
"""
|
"""
|
||||||
@@ -540,13 +540,15 @@ class ModelManager(object):
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
) -> dict:
|
) -> Union[dict, None]:
|
||||||
"""
|
"""
|
||||||
Returns a dict describing one installed model, using
|
Returns a dict describing one installed model, using
|
||||||
the combined format of the list_models() method.
|
the combined format of the list_models() method.
|
||||||
"""
|
"""
|
||||||
models = self.list_models(base_model, model_type, model_name)
|
models = self.list_models(base_model, model_type, model_name)
|
||||||
return models[0] if models else None
|
if len(models) > 1:
|
||||||
|
return models[0]
|
||||||
|
return None
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
self,
|
self,
|
||||||
@@ -560,7 +562,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
model_keys = (
|
model_keys = (
|
||||||
[self.create_key(model_name, base_model, model_type)]
|
[self.create_key(model_name, base_model, model_type)]
|
||||||
if model_name
|
if model_name and base_model and model_type
|
||||||
else sorted(self.models, key=str.casefold)
|
else sorted(self.models, key=str.casefold)
|
||||||
)
|
)
|
||||||
models = []
|
models = []
|
||||||
@@ -596,7 +598,7 @@ class ModelManager(object):
|
|||||||
Print a table of models and their descriptions. This needs to be redone
|
Print a table of models and their descriptions. This needs to be redone
|
||||||
"""
|
"""
|
||||||
# TODO: redo
|
# TODO: redo
|
||||||
for model_type, model_dict in self.list_models().items():
|
for model_dict in self.list_models():
|
||||||
for model_name, model_info in model_dict.items():
|
for model_name, model_info in model_dict.items():
|
||||||
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
||||||
print(line)
|
print(line)
|
||||||
@@ -670,7 +672,7 @@ class ModelManager(object):
|
|||||||
# TODO: if path changed and old_model.path inside models folder should we delete this too?
|
# TODO: if path changed and old_model.path inside models folder should we delete this too?
|
||||||
|
|
||||||
# remove conversion cache as config changed
|
# remove conversion cache as config changed
|
||||||
old_model_path = self.app_config.root_path / old_model.path
|
old_model_path = self.resolve_model_path(old_model.path)
|
||||||
old_model_cache = self._get_model_cache_path(old_model_path)
|
old_model_cache = self._get_model_cache_path(old_model_path)
|
||||||
if old_model_cache.exists():
|
if old_model_cache.exists():
|
||||||
if old_model_cache.is_dir():
|
if old_model_cache.is_dir():
|
||||||
@@ -699,8 +701,8 @@ class ModelManager(object):
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
new_name: str = None,
|
new_name: Optional[str] = None,
|
||||||
new_base: BaseModelType = None,
|
new_base: Optional[BaseModelType] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Rename or rebase a model.
|
Rename or rebase a model.
|
||||||
@@ -753,7 +755,7 @@ class ModelManager(object):
|
|||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||||
dest_directory: Optional[Path] = None,
|
dest_directory: Optional[Path] = None,
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
@@ -767,6 +769,10 @@ class ModelManager(object):
|
|||||||
This will raise a ValueError unless the model is a checkpoint.
|
This will raise a ValueError unless the model is a checkpoint.
|
||||||
"""
|
"""
|
||||||
info = self.model_info(model_name, base_model, model_type)
|
info = self.model_info(model_name, base_model, model_type)
|
||||||
|
|
||||||
|
if info is None:
|
||||||
|
raise FileNotFoundError(f"model not found: {model_name}")
|
||||||
|
|
||||||
if info["model_format"] != "checkpoint":
|
if info["model_format"] != "checkpoint":
|
||||||
raise ValueError(f"not a checkpoint format model: {model_name}")
|
raise ValueError(f"not a checkpoint format model: {model_name}")
|
||||||
|
|
||||||
@@ -780,7 +786,7 @@ class ModelManager(object):
|
|||||||
model_type,
|
model_type,
|
||||||
**submodel,
|
**submodel,
|
||||||
)
|
)
|
||||||
checkpoint_path = self.app_config.root_path / info["path"]
|
checkpoint_path = self.resolve_model_path(info["path"])
|
||||||
old_diffusers_path = self.resolve_model_path(model.location)
|
old_diffusers_path = self.resolve_model_path(model.location)
|
||||||
new_diffusers_path = (
|
new_diffusers_path = (
|
||||||
dest_directory or self.app_config.models_path / base_model.value / model_type.value
|
dest_directory or self.app_config.models_path / base_model.value / model_type.value
|
||||||
@@ -836,7 +842,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
return search_folder, found_models
|
return search_folder, found_models
|
||||||
|
|
||||||
def commit(self, conf_file: Path = None) -> None:
|
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
"""
|
"""
|
||||||
@@ -983,7 +989,7 @@ class ModelManager(object):
|
|||||||
# LS: hacky
|
# LS: hacky
|
||||||
# Patch in the SD VAE from core so that it is available for use by the UI
|
# Patch in the SD VAE from core so that it is available for use by the UI
|
||||||
try:
|
try:
|
||||||
self.heuristic_import({self.resolve_model_path("core/convert/sd-vae-ft-mse")})
|
self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -992,7 +998,7 @@ class ModelManager(object):
|
|||||||
model_manager=self,
|
model_manager=self,
|
||||||
prediction_type_helper=ask_user_for_prediction_type,
|
prediction_type_helper=ask_user_for_prediction_type,
|
||||||
)
|
)
|
||||||
known_paths = {config.root_path / x["path"] for x in self.list_models()}
|
known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
|
||||||
directories = {
|
directories = {
|
||||||
config.root_path / x
|
config.root_path / x
|
||||||
for x in [
|
for x in [
|
||||||
@@ -1011,7 +1017,7 @@ class ModelManager(object):
|
|||||||
def heuristic_import(
|
def heuristic_import(
|
||||||
self,
|
self,
|
||||||
items_to_import: Set[str],
|
items_to_import: Set[str],
|
||||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||||
) -> Dict[str, AddModelResult]:
|
) -> Dict[str, AddModelResult]:
|
||||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class ModelMerger(object):
|
|||||||
self,
|
self,
|
||||||
model_paths: List[Path],
|
model_paths: List[Path],
|
||||||
alpha: float = 0.5,
|
alpha: float = 0.5,
|
||||||
interp: MergeInterpolationMethod = None,
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> DiffusionPipeline:
|
) -> DiffusionPipeline:
|
||||||
@@ -73,7 +73,7 @@ class ModelMerger(object):
|
|||||||
base_model: Union[BaseModelType, str],
|
base_model: Union[BaseModelType, str],
|
||||||
merged_model_name: str,
|
merged_model_name: str,
|
||||||
alpha: float = 0.5,
|
alpha: float = 0.5,
|
||||||
interp: MergeInterpolationMethod = None,
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
merge_dest_directory: Optional[Path] = None,
|
merge_dest_directory: Optional[Path] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -122,7 +122,7 @@ class ModelMerger(object):
|
|||||||
dump_path.mkdir(parents=True, exist_ok=True)
|
dump_path.mkdir(parents=True, exist_ok=True)
|
||||||
dump_path = dump_path / merged_model_name
|
dump_path = dump_path / merged_model_name
|
||||||
|
|
||||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||||
attributes = dict(
|
attributes = dict(
|
||||||
path=str(dump_path),
|
path=str(dump_path),
|
||||||
description=f"Merge of models {', '.join(model_names)}",
|
description=f"Merge of models {', '.join(model_names)}",
|
||||||
|
|||||||
@@ -315,21 +315,38 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
|||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
|
|
||||||
|
# SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
|
||||||
|
# There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
|
||||||
|
# misclassified as SD-1
|
||||||
|
key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
||||||
|
if key in checkpoint and checkpoint[key].shape[0] == 320:
|
||||||
|
return BaseModelType.StableDiffusion2
|
||||||
|
|
||||||
|
key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
|
||||||
|
if key in checkpoint:
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
|
||||||
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
||||||
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
|
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
||||||
|
key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
|
||||||
|
|
||||||
lora_token_vector_length = (
|
lora_token_vector_length = (
|
||||||
checkpoint[key1].shape[1]
|
checkpoint[key1].shape[1]
|
||||||
if key1 in checkpoint
|
if key1 in checkpoint
|
||||||
else checkpoint[key2].shape[0]
|
else checkpoint[key2].shape[1]
|
||||||
if key2 in checkpoint
|
if key2 in checkpoint
|
||||||
else 768
|
else checkpoint[key3].shape[0]
|
||||||
|
if key3 in checkpoint
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if lora_token_vector_length == 768:
|
if lora_token_vector_length == 768:
|
||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
elif lora_token_vector_length == 1024:
|
elif lora_token_vector_length == 1024:
|
||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
else:
|
else:
|
||||||
return None
|
raise InvalidModelException(f"Unknown LoRA type")
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||||
|
|||||||
@@ -292,8 +292,9 @@ class DiffusersModel(ModelBase):
|
|||||||
)
|
)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# print("====ERR LOAD====")
|
if not str(e).startswith("Error no file"):
|
||||||
# print(f"{variant}: {e}")
|
print("====ERR LOAD====")
|
||||||
|
print(f"{variant}: {e}")
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Dict, Union, Literal, Any
|
||||||
|
from pathlib import Path
|
||||||
|
from safetensors.torch import load_file
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelBase,
|
ModelBase,
|
||||||
ModelConfigBase,
|
ModelConfigBase,
|
||||||
@@ -13,9 +15,6 @@ from .base import (
|
|||||||
ModelNotFoundException,
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: naming
|
|
||||||
from ..lora import LoRAModel as LoRAModelRaw
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelFormat(str, Enum):
|
class LoRAModelFormat(str, Enum):
|
||||||
LyCORIS = "lycoris"
|
LyCORIS = "lycoris"
|
||||||
@@ -50,6 +49,7 @@ class LoRAModel(ModelBase):
|
|||||||
model = LoRAModelRaw.from_checkpoint(
|
model = LoRAModelRaw.from_checkpoint(
|
||||||
file_path=self.model_path,
|
file_path=self.model_path,
|
||||||
dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
|
base_model=self.base_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_size = model.calc_size()
|
self.model_size = model.calc_size()
|
||||||
@@ -87,3 +87,582 @@ class LoRAModel(ModelBase):
|
|||||||
raise NotImplementedError("Diffusers lora not supported")
|
raise NotImplementedError("Diffusers lora not supported")
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
|
class LoRALayerBase:
|
||||||
|
# rank: Optional[int]
|
||||||
|
# alpha: Optional[float]
|
||||||
|
# bias: Optional[torch.Tensor]
|
||||||
|
# layer_key: str
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def scale(self):
|
||||||
|
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
if "alpha" in values:
|
||||||
|
self.alpha = values["alpha"].item()
|
||||||
|
else:
|
||||||
|
self.alpha = None
|
||||||
|
|
||||||
|
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||||
|
self.bias = torch.sparse_coo_tensor(
|
||||||
|
values["bias_indices"],
|
||||||
|
values["bias_values"],
|
||||||
|
tuple(values["bias_size"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
self.rank = None # set in layer implementation
|
||||||
|
self.layer_key = layer_key
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
||||||
|
multiplier: float,
|
||||||
|
):
|
||||||
|
if type(module) == torch.nn.Conv2d:
|
||||||
|
op = torch.nn.functional.conv2d
|
||||||
|
extra_args = dict(
|
||||||
|
stride=module.stride,
|
||||||
|
padding=module.padding,
|
||||||
|
dilation=module.dilation,
|
||||||
|
groups=module.groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
op = torch.nn.functional.linear
|
||||||
|
extra_args = {}
|
||||||
|
|
||||||
|
weight = self.get_weight()
|
||||||
|
|
||||||
|
bias = self.bias if self.bias is not None else 0
|
||||||
|
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||||
|
return (
|
||||||
|
op(
|
||||||
|
*input_h,
|
||||||
|
(weight + bias).view(module.weight.shape),
|
||||||
|
None,
|
||||||
|
**extra_args,
|
||||||
|
)
|
||||||
|
* multiplier
|
||||||
|
* scale
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = 0
|
||||||
|
for val in [self.bias]:
|
||||||
|
if val is not None:
|
||||||
|
model_size += val.nelement() * val.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
if self.bias is not None:
|
||||||
|
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: find and debug lora/locon with bias
|
||||||
|
class LoRALayer(LoRALayerBase):
|
||||||
|
# up: torch.Tensor
|
||||||
|
# mid: Optional[torch.Tensor]
|
||||||
|
# down: torch.Tensor
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
|
self.up = values["lora_up.weight"]
|
||||||
|
self.down = values["lora_down.weight"]
|
||||||
|
if "lora_mid.weight" in values:
|
||||||
|
self.mid = values["lora_mid.weight"]
|
||||||
|
else:
|
||||||
|
self.mid = None
|
||||||
|
|
||||||
|
self.rank = self.down.shape[0]
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
if self.mid is not None:
|
||||||
|
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||||
|
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||||
|
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||||
|
else:
|
||||||
|
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = super().calc_size()
|
||||||
|
for val in [self.up, self.mid, self.down]:
|
||||||
|
if val is not None:
|
||||||
|
model_size += val.nelement() * val.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.up = self.up.to(device=device, dtype=dtype)
|
||||||
|
self.down = self.down.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if self.mid is not None:
|
||||||
|
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class LoHALayer(LoRALayerBase):
|
||||||
|
# w1_a: torch.Tensor
|
||||||
|
# w1_b: torch.Tensor
|
||||||
|
# w2_a: torch.Tensor
|
||||||
|
# w2_b: torch.Tensor
|
||||||
|
# t1: Optional[torch.Tensor] = None
|
||||||
|
# t2: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
|
self.w1_a = values["hada_w1_a"]
|
||||||
|
self.w1_b = values["hada_w1_b"]
|
||||||
|
self.w2_a = values["hada_w2_a"]
|
||||||
|
self.w2_b = values["hada_w2_b"]
|
||||||
|
|
||||||
|
if "hada_t1" in values:
|
||||||
|
self.t1 = values["hada_t1"]
|
||||||
|
else:
|
||||||
|
self.t1 = None
|
||||||
|
|
||||||
|
if "hada_t2" in values:
|
||||||
|
self.t2 = values["hada_t2"]
|
||||||
|
else:
|
||||||
|
self.t2 = None
|
||||||
|
|
||||||
|
self.rank = self.w1_b.shape[0]
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
if self.t1 is None:
|
||||||
|
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||||
|
|
||||||
|
else:
|
||||||
|
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||||
|
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||||
|
weight = rebuild1 * rebuild2
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = super().calc_size()
|
||||||
|
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||||
|
if val is not None:
|
||||||
|
model_size += val.nelement() * val.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||||
|
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||||
|
if self.t1 is not None:
|
||||||
|
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||||
|
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||||
|
if self.t2 is not None:
|
||||||
|
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class LoKRLayer(LoRALayerBase):
|
||||||
|
# w1: Optional[torch.Tensor] = None
|
||||||
|
# w1_a: Optional[torch.Tensor] = None
|
||||||
|
# w1_b: Optional[torch.Tensor] = None
|
||||||
|
# w2: Optional[torch.Tensor] = None
|
||||||
|
# w2_a: Optional[torch.Tensor] = None
|
||||||
|
# w2_b: Optional[torch.Tensor] = None
|
||||||
|
# t2: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
|
if "lokr_w1" in values:
|
||||||
|
self.w1 = values["lokr_w1"]
|
||||||
|
self.w1_a = None
|
||||||
|
self.w1_b = None
|
||||||
|
else:
|
||||||
|
self.w1 = None
|
||||||
|
self.w1_a = values["lokr_w1_a"]
|
||||||
|
self.w1_b = values["lokr_w1_b"]
|
||||||
|
|
||||||
|
if "lokr_w2" in values:
|
||||||
|
self.w2 = values["lokr_w2"]
|
||||||
|
self.w2_a = None
|
||||||
|
self.w2_b = None
|
||||||
|
else:
|
||||||
|
self.w2 = None
|
||||||
|
self.w2_a = values["lokr_w2_a"]
|
||||||
|
self.w2_b = values["lokr_w2_b"]
|
||||||
|
|
||||||
|
if "lokr_t2" in values:
|
||||||
|
self.t2 = values["lokr_t2"]
|
||||||
|
else:
|
||||||
|
self.t2 = None
|
||||||
|
|
||||||
|
if "lokr_w1_b" in values:
|
||||||
|
self.rank = values["lokr_w1_b"].shape[0]
|
||||||
|
elif "lokr_w2_b" in values:
|
||||||
|
self.rank = values["lokr_w2_b"].shape[0]
|
||||||
|
else:
|
||||||
|
self.rank = None # unscaled
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
w1 = self.w1
|
||||||
|
if w1 is None:
|
||||||
|
w1 = self.w1_a @ self.w1_b
|
||||||
|
|
||||||
|
w2 = self.w2
|
||||||
|
if w2 is None:
|
||||||
|
if self.t2 is None:
|
||||||
|
w2 = self.w2_a @ self.w2_b
|
||||||
|
else:
|
||||||
|
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||||
|
|
||||||
|
if len(w2.shape) == 4:
|
||||||
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
|
w2 = w2.contiguous()
|
||||||
|
weight = torch.kron(w1, w2)
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = super().calc_size()
|
||||||
|
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||||
|
if val is not None:
|
||||||
|
model_size += val.nelement() * val.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if self.w1 is not None:
|
||||||
|
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||||
|
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if self.w2 is not None:
|
||||||
|
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||||
|
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if self.t2 is not None:
|
||||||
|
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class FullLayer(LoRALayerBase):
|
||||||
|
# weight: torch.Tensor
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
|
self.weight = values["diff"]
|
||||||
|
|
||||||
|
if len(values.keys()) > 1:
|
||||||
|
_keys = list(values.keys())
|
||||||
|
_keys.remove("diff")
|
||||||
|
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
|
||||||
|
|
||||||
|
self.rank = None # unscaled
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
return self.weight
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = super().calc_size()
|
||||||
|
model_size += self.weight.nelement() * self.weight.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
||||||
|
class LoRAModelRaw: # (torch.nn.Module):
|
||||||
|
_name: str
|
||||||
|
layers: Dict[str, LoRALayer]
|
||||||
|
_device: torch.device
|
||||||
|
_dtype: torch.dtype
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
layers: Dict[str, LoRALayer],
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
self._name = name
|
||||||
|
self._device = device or torch.cpu
|
||||||
|
self._dtype = dtype or torch.float32
|
||||||
|
self.layers = layers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self._device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self._dtype
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
# TODO: try revert if exception?
|
||||||
|
for key, layer in self.layers.items():
|
||||||
|
layer.to(device=device, dtype=dtype)
|
||||||
|
self._device = device
|
||||||
|
self._dtype = dtype
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = 0
|
||||||
|
for _, layer in self.layers.items():
|
||||||
|
model_size += layer.calc_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _convert_sdxl_compvis_keys(cls, state_dict):
|
||||||
|
new_state_dict = dict()
|
||||||
|
for full_key, value in state_dict.items():
|
||||||
|
if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||||
|
continue # clip same
|
||||||
|
|
||||||
|
if not full_key.startswith("lora_unet_"):
|
||||||
|
raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
|
||||||
|
src_key = full_key.replace("lora_unet_", "")
|
||||||
|
try:
|
||||||
|
dst_key = None
|
||||||
|
while "_" in src_key:
|
||||||
|
if src_key in SDXL_UNET_COMPVIS_MAP:
|
||||||
|
dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
|
||||||
|
break
|
||||||
|
src_key = "_".join(src_key.split("_")[:-1])
|
||||||
|
|
||||||
|
if dst_key is None:
|
||||||
|
raise Exception(f"Unknown sdxl lora key - {full_key}")
|
||||||
|
new_key = full_key.replace(src_key, dst_key)
|
||||||
|
except:
|
||||||
|
print(SDXL_UNET_COMPVIS_MAP)
|
||||||
|
raise
|
||||||
|
new_state_dict[new_key] = value
|
||||||
|
return new_state_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_checkpoint(
|
||||||
|
cls,
|
||||||
|
file_path: Union[str, Path],
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
base_model: Optional[BaseModelType] = None,
|
||||||
|
):
|
||||||
|
device = device or torch.device("cpu")
|
||||||
|
dtype = dtype or torch.float32
|
||||||
|
|
||||||
|
if isinstance(file_path, str):
|
||||||
|
file_path = Path(file_path)
|
||||||
|
|
||||||
|
model = cls(
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
name=file_path.stem, # TODO:
|
||||||
|
layers=dict(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if file_path.suffix == ".safetensors":
|
||||||
|
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(file_path, map_location="cpu")
|
||||||
|
|
||||||
|
state_dict = cls._group_state(state_dict)
|
||||||
|
|
||||||
|
if base_model == BaseModelType.StableDiffusionXL:
|
||||||
|
state_dict = cls._convert_sdxl_compvis_keys(state_dict)
|
||||||
|
|
||||||
|
for layer_key, values in state_dict.items():
|
||||||
|
# lora and locon
|
||||||
|
if "lora_down.weight" in values:
|
||||||
|
layer = LoRALayer(layer_key, values)
|
||||||
|
|
||||||
|
# loha
|
||||||
|
elif "hada_w1_b" in values:
|
||||||
|
layer = LoHALayer(layer_key, values)
|
||||||
|
|
||||||
|
# lokr
|
||||||
|
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||||
|
layer = LoKRLayer(layer_key, values)
|
||||||
|
|
||||||
|
elif "diff" in values:
|
||||||
|
layer = FullLayer(layer_key, values)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# TODO: ia3/... format
|
||||||
|
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||||
|
raise Exception("Unknown lora format!")
|
||||||
|
|
||||||
|
# lower memory consumption by removing already parsed layer values
|
||||||
|
state_dict[layer_key].clear()
|
||||||
|
|
||||||
|
layer.to(device=device, dtype=dtype)
|
||||||
|
model.layers[layer_key] = layer
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _group_state(state_dict: dict):
|
||||||
|
state_dict_groupped = dict()
|
||||||
|
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
stem, leaf = key.split(".", 1)
|
||||||
|
if stem not in state_dict_groupped:
|
||||||
|
state_dict_groupped[stem] = dict()
|
||||||
|
state_dict_groupped[stem][leaf] = value
|
||||||
|
|
||||||
|
return state_dict_groupped
|
||||||
|
|
||||||
|
|
||||||
|
# code from
|
||||||
|
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||||
|
def make_sdxl_unet_conversion_map():
|
||||||
|
unet_conversion_map_layer = []
|
||||||
|
|
||||||
|
for i in range(3): # num_blocks is 3 in sdxl
|
||||||
|
# loop over downblocks/upblocks
|
||||||
|
for j in range(2):
|
||||||
|
# loop over resnets/attentions for downblocks
|
||||||
|
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||||
|
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||||
|
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
# no attention layers in down_blocks.3
|
||||||
|
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||||
|
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||||
|
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||||
|
|
||||||
|
for j in range(3):
|
||||||
|
# loop over resnets/attentions for upblocks
|
||||||
|
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||||
|
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||||
|
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||||
|
|
||||||
|
# if i > 0: commentout for sdxl
|
||||||
|
# no attention layers in up_blocks.0
|
||||||
|
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||||
|
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||||
|
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
# no downsample in down_blocks.3
|
||||||
|
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||||
|
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||||
|
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||||
|
|
||||||
|
# no upsample in up_blocks.3
|
||||||
|
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||||
|
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||||
|
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||||
|
|
||||||
|
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||||
|
sd_mid_atn_prefix = "middle_block.1."
|
||||||
|
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||||
|
|
||||||
|
for j in range(2):
|
||||||
|
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||||
|
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||||
|
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||||
|
|
||||||
|
unet_conversion_map_resnet = [
|
||||||
|
# (stable-diffusion, HF Diffusers)
|
||||||
|
("in_layers.0.", "norm1."),
|
||||||
|
("in_layers.2.", "conv1."),
|
||||||
|
("out_layers.0.", "norm2."),
|
||||||
|
("out_layers.3.", "conv2."),
|
||||||
|
("emb_layers.1.", "time_emb_proj."),
|
||||||
|
("skip_connection.", "conv_shortcut."),
|
||||||
|
]
|
||||||
|
|
||||||
|
unet_conversion_map = []
|
||||||
|
for sd, hf in unet_conversion_map_layer:
|
||||||
|
if "resnets" in hf:
|
||||||
|
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||||
|
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||||
|
else:
|
||||||
|
unet_conversion_map.append((sd, hf))
|
||||||
|
|
||||||
|
for j in range(2):
|
||||||
|
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||||
|
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||||
|
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||||
|
|
||||||
|
for j in range(2):
|
||||||
|
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||||
|
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||||
|
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||||
|
|
||||||
|
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||||
|
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||||
|
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||||
|
|
||||||
|
return unet_conversion_map
|
||||||
|
|
||||||
|
|
||||||
|
SDXL_UNET_COMPVIS_MAP = {
|
||||||
|
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
|
||||||
|
for sd, hf in make_sdxl_unet_conversion_map()
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from enum import Enum
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelConfigBase,
|
ModelConfigBase,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@@ -263,6 +264,8 @@ def _convert_ckpt_and_cache(
|
|||||||
weights = app_config.models_path / model_config.path
|
weights = app_config.models_path / model_config.path
|
||||||
config_file = app_config.root_path / model_config.config
|
config_file = app_config.root_path / model_config.config
|
||||||
output_path = Path(output_path)
|
output_path = Path(output_path)
|
||||||
|
variant = model_config.variant
|
||||||
|
pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline
|
||||||
|
|
||||||
# return cached version if it exists
|
# return cached version if it exists
|
||||||
if output_path.exists():
|
if output_path.exists():
|
||||||
@@ -289,6 +292,7 @@ def _convert_ckpt_and_cache(
|
|||||||
original_config_file=config_file,
|
original_config_file=config_file,
|
||||||
extract_ema=True,
|
extract_ema=True,
|
||||||
scan_needed=True,
|
scan_needed=True,
|
||||||
|
pipeline_class=pipeline_class,
|
||||||
from_safetensors=weights.suffix == ".safetensors",
|
from_safetensors=weights.suffix == ".safetensors",
|
||||||
precision=torch_dtype(choose_torch_device()),
|
precision=torch_dtype(choose_torch_device()),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|||||||
@@ -78,10 +78,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
self.cross_attention_control_context = None
|
self.cross_attention_control_context = None
|
||||||
self.sequential_guidance = config.sequential_guidance
|
self.sequential_guidance = config.sequential_guidance
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def custom_attention_context(
|
def custom_attention_context(
|
||||||
cls,
|
self,
|
||||||
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
||||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||||
step_count: int,
|
step_count: int,
|
||||||
@@ -91,18 +90,19 @@ class InvokeAIDiffuserComponent:
|
|||||||
old_attn_processors = unet.attn_processors
|
old_attn_processors = unet.attn_processors
|
||||||
# Load lora conditions into the model
|
# Load lora conditions into the model
|
||||||
if extra_conditioning_info.wants_cross_attention_control:
|
if extra_conditioning_info.wants_cross_attention_control:
|
||||||
cross_attention_control_context = Context(
|
self.cross_attention_control_context = Context(
|
||||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||||
step_count=step_count,
|
step_count=step_count,
|
||||||
)
|
)
|
||||||
setup_cross_attention_control_attention_processors(
|
setup_cross_attention_control_attention_processors(
|
||||||
unet,
|
unet,
|
||||||
cross_attention_control_context,
|
self.cross_attention_control_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
||||||
|
self.cross_attention_control_context = None
|
||||||
if old_attn_processors is not None:
|
if old_attn_processors is not None:
|
||||||
unet.set_attn_processor(old_attn_processors)
|
unet.set_attn_processor(old_attn_processors)
|
||||||
# TODO resuscitate attention map saving
|
# TODO resuscitate attention map saving
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
from packaging import version
|
||||||
|
import platform
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
@@ -30,7 +32,7 @@ def choose_precision(device: torch.device) -> str:
|
|||||||
device_name = torch.cuda.get_device_name(device)
|
device_name = torch.cuda.get_device_name(device)
|
||||||
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
||||||
return "float16"
|
return "float16"
|
||||||
elif device.type == "mps":
|
elif device.type == "mps" and version.parse(platform.mac_ver()[0]) < version.parse("14.0.0"):
|
||||||
return "float16"
|
return "float16"
|
||||||
return "float32"
|
return "float32"
|
||||||
|
|
||||||
|
|||||||
169
invokeai/frontend/web/dist/assets/App-3594329a.js
vendored
Normal file
169
invokeai/frontend/web/dist/assets/App-3594329a.js
vendored
Normal file
File diff suppressed because one or more lines are too long
169
invokeai/frontend/web/dist/assets/App-44cdaaf3.js
vendored
169
invokeai/frontend/web/dist/assets/App-44cdaaf3.js
vendored
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,4 +1,4 @@
|
|||||||
import{A as m,f$ as Je,z as y,a4 as Ka,g0 as Xa,af as va,aj as d,g1 as b,g2 as t,g3 as Ya,g4 as h,g5 as ua,g6 as Ja,g7 as Qa,aI as Za,g8 as et,ad as rt,g9 as at}from"./index-18f2f740.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./MantineProvider-b20a2267.js";var za=String.raw,Ca=za`
|
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-de589048.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-11348abc.js";var za=String.raw,Ca=za`
|
||||||
:root,
|
:root,
|
||||||
:host {
|
:host {
|
||||||
--chakra-vh: 100vh;
|
--chakra-vh: 100vh;
|
||||||
125
invokeai/frontend/web/dist/assets/index-18f2f740.js
vendored
125
invokeai/frontend/web/dist/assets/index-18f2f740.js
vendored
File diff suppressed because one or more lines are too long
151
invokeai/frontend/web/dist/assets/index-de589048.js
vendored
Normal file
151
invokeai/frontend/web/dist/assets/index-de589048.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
invokeai/frontend/web/dist/assets/menu-11348abc.js
vendored
Normal file
1
invokeai/frontend/web/dist/assets/menu-11348abc.js
vendored
Normal file
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@@ -12,7 +12,7 @@
|
|||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
<script type="module" crossorigin src="./assets/index-18f2f740.js"></script>
|
<script type="module" crossorigin src="./assets/index-de589048.js"></script>
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
<body dir="ltr">
|
<body dir="ltr">
|
||||||
|
|||||||
3
invokeai/frontend/web/dist/locales/en.json
vendored
3
invokeai/frontend/web/dist/locales/en.json
vendored
@@ -124,7 +124,8 @@
|
|||||||
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
||||||
"deleteImagePermanent": "Deleted images cannot be restored.",
|
"deleteImagePermanent": "Deleted images cannot be restored.",
|
||||||
"images": "Images",
|
"images": "Images",
|
||||||
"assets": "Assets"
|
"assets": "Assets",
|
||||||
|
"autoAssignBoardOnClick": "Auto-Assign Board on Click"
|
||||||
},
|
},
|
||||||
"hotkeys": {
|
"hotkeys": {
|
||||||
"keyboardShortcuts": "Keyboard Shortcuts",
|
"keyboardShortcuts": "Keyboard Shortcuts",
|
||||||
|
|||||||
@@ -23,7 +23,7 @@
|
|||||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||||
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
||||||
"build": "yarn run lint && vite build",
|
"build": "yarn run lint && vite build",
|
||||||
"typegen": "npx ts-node scripts/typegen.ts",
|
"typegen": "node scripts/typegen.js",
|
||||||
"preview": "vite preview",
|
"preview": "vite preview",
|
||||||
"lint:madge": "madge --circular src/main.tsx",
|
"lint:madge": "madge --circular src/main.tsx",
|
||||||
"lint:eslint": "eslint --max-warnings=0 .",
|
"lint:eslint": "eslint --max-warnings=0 .",
|
||||||
|
|||||||
@@ -124,7 +124,8 @@
|
|||||||
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
||||||
"deleteImagePermanent": "Deleted images cannot be restored.",
|
"deleteImagePermanent": "Deleted images cannot be restored.",
|
||||||
"images": "Images",
|
"images": "Images",
|
||||||
"assets": "Assets"
|
"assets": "Assets",
|
||||||
|
"autoAssignBoardOnClick": "Auto-Assign Board on Click"
|
||||||
},
|
},
|
||||||
"hotkeys": {
|
"hotkeys": {
|
||||||
"keyboardShortcuts": "Keyboard Shortcuts",
|
"keyboardShortcuts": "Keyboard Shortcuts",
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { PartialAppConfig } from 'app/types/invokeai';
|
import { PartialAppConfig } from 'app/types/invokeai';
|
||||||
import ImageUploader from 'common/components/ImageUploader';
|
import ImageUploader from 'common/components/ImageUploader';
|
||||||
|
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
|
||||||
|
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
|
||||||
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
||||||
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
|
|
||||||
import SiteHeader from 'features/system/components/SiteHeader';
|
import SiteHeader from 'features/system/components/SiteHeader';
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||||
@@ -16,7 +17,6 @@ import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
|||||||
import i18n from 'i18n';
|
import i18n from 'i18n';
|
||||||
import { size } from 'lodash-es';
|
import { size } from 'lodash-es';
|
||||||
import { ReactNode, memo, useEffect } from 'react';
|
import { ReactNode, memo, useEffect } from 'react';
|
||||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
|
||||||
import GlobalHotkeys from './GlobalHotkeys';
|
import GlobalHotkeys from './GlobalHotkeys';
|
||||||
import Toaster from './Toaster';
|
import Toaster from './Toaster';
|
||||||
|
|
||||||
@@ -84,7 +84,7 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
|||||||
</Portal>
|
</Portal>
|
||||||
</Grid>
|
</Grid>
|
||||||
<DeleteImageModal />
|
<DeleteImageModal />
|
||||||
<UpdateImageBoardModal />
|
<ChangeBoardModal />
|
||||||
<Toaster />
|
<Toaster />
|
||||||
<GlobalHotkeys />
|
<GlobalHotkeys />
|
||||||
</>
|
</>
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (props.dragData.payloadType === 'IMAGE_NAMES') {
|
if (props.dragData.payloadType === 'IMAGE_DTOS') {
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
@@ -71,7 +71,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
|||||||
...STYLES,
|
...STYLES,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Heading>{props.dragData.payload.image_names.length}</Heading>
|
<Heading>{props.dragData.payload.imageDTOs.length}</Heading>
|
||||||
<Heading size="sm">Images</Heading>
|
<Heading size="sm">Images</Heading>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -18,27 +18,32 @@ import {
|
|||||||
DragStartEvent,
|
DragStartEvent,
|
||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
} from './typesafeDnd';
|
} from './typesafeDnd';
|
||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
|
||||||
type ImageDndContextProps = PropsWithChildren;
|
type ImageDndContextProps = PropsWithChildren;
|
||||||
|
|
||||||
const ImageDndContext = (props: ImageDndContextProps) => {
|
const ImageDndContext = (props: ImageDndContextProps) => {
|
||||||
const [activeDragData, setActiveDragData] =
|
const [activeDragData, setActiveDragData] =
|
||||||
useState<TypesafeDraggableData | null>(null);
|
useState<TypesafeDraggableData | null>(null);
|
||||||
|
const log = logger('images');
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const handleDragStart = useCallback((event: DragStartEvent) => {
|
const handleDragStart = useCallback(
|
||||||
console.log('dragStart', event.active.data.current);
|
(event: DragStartEvent) => {
|
||||||
const activeData = event.active.data.current;
|
log.trace({ dragData: event.active.data.current }, 'Drag started');
|
||||||
if (!activeData) {
|
const activeData = event.active.data.current;
|
||||||
return;
|
if (!activeData) {
|
||||||
}
|
return;
|
||||||
setActiveDragData(activeData);
|
}
|
||||||
}, []);
|
setActiveDragData(activeData);
|
||||||
|
},
|
||||||
|
[log]
|
||||||
|
);
|
||||||
|
|
||||||
const handleDragEnd = useCallback(
|
const handleDragEnd = useCallback(
|
||||||
(event: DragEndEvent) => {
|
(event: DragEndEvent) => {
|
||||||
console.log('dragEnd', event.active.data.current);
|
log.trace({ dragData: event.active.data.current }, 'Drag ended');
|
||||||
const overData = event.over?.data.current;
|
const overData = event.over?.data.current;
|
||||||
if (!activeDragData || !overData) {
|
if (!activeDragData || !overData) {
|
||||||
return;
|
return;
|
||||||
@@ -46,7 +51,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
|||||||
dispatch(dndDropped({ overData, activeData: activeDragData }));
|
dispatch(dndDropped({ overData, activeData: activeDragData }));
|
||||||
setActiveDragData(null);
|
setActiveDragData(null);
|
||||||
},
|
},
|
||||||
[activeDragData, dispatch]
|
[activeDragData, dispatch, log]
|
||||||
);
|
);
|
||||||
|
|
||||||
const mouseSensor = useSensor(MouseSensor, {
|
const mouseSensor = useSensor(MouseSensor, {
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import {
|
|||||||
useDraggable as useOriginalDraggable,
|
useDraggable as useOriginalDraggable,
|
||||||
useDroppable as useOriginalDroppable,
|
useDroppable as useOriginalDroppable,
|
||||||
} from '@dnd-kit/core';
|
} from '@dnd-kit/core';
|
||||||
import { BoardId } from 'features/gallery/store/types';
|
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
type BaseDropData = {
|
type BaseDropData = {
|
||||||
@@ -54,9 +53,13 @@ export type AddToBatchDropData = BaseDropData & {
|
|||||||
actionType: 'ADD_TO_BATCH';
|
actionType: 'ADD_TO_BATCH';
|
||||||
};
|
};
|
||||||
|
|
||||||
export type MoveBoardDropData = BaseDropData & {
|
export type AddToBoardDropData = BaseDropData & {
|
||||||
actionType: 'MOVE_BOARD';
|
actionType: 'ADD_TO_BOARD';
|
||||||
context: { boardId: BoardId };
|
context: { boardId: string };
|
||||||
|
};
|
||||||
|
|
||||||
|
export type RemoveFromBoardDropData = BaseDropData & {
|
||||||
|
actionType: 'REMOVE_FROM_BOARD';
|
||||||
};
|
};
|
||||||
|
|
||||||
export type TypesafeDroppableData =
|
export type TypesafeDroppableData =
|
||||||
@@ -67,7 +70,8 @@ export type TypesafeDroppableData =
|
|||||||
| NodesImageDropData
|
| NodesImageDropData
|
||||||
| AddToBatchDropData
|
| AddToBatchDropData
|
||||||
| NodesMultiImageDropData
|
| NodesMultiImageDropData
|
||||||
| MoveBoardDropData;
|
| AddToBoardDropData
|
||||||
|
| RemoveFromBoardDropData;
|
||||||
|
|
||||||
type BaseDragData = {
|
type BaseDragData = {
|
||||||
id: string;
|
id: string;
|
||||||
@@ -78,14 +82,12 @@ export type ImageDraggableData = BaseDragData & {
|
|||||||
payload: { imageDTO: ImageDTO };
|
payload: { imageDTO: ImageDTO };
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ImageNamesDraggableData = BaseDragData & {
|
export type ImageDTOsDraggableData = BaseDragData & {
|
||||||
payloadType: 'IMAGE_NAMES';
|
payloadType: 'IMAGE_DTOS';
|
||||||
payload: { image_names: string[] };
|
payload: { imageDTOs: ImageDTO[] };
|
||||||
};
|
};
|
||||||
|
|
||||||
export type TypesafeDraggableData =
|
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData;
|
||||||
| ImageDraggableData
|
|
||||||
| ImageNamesDraggableData;
|
|
||||||
|
|
||||||
interface UseDroppableTypesafeArguments
|
interface UseDroppableTypesafeArguments
|
||||||
extends Omit<UseDroppableArguments, 'data'> {
|
extends Omit<UseDroppableArguments, 'data'> {
|
||||||
@@ -156,14 +158,39 @@ export const isValidDrop = (
|
|||||||
case 'SET_NODES_IMAGE':
|
case 'SET_NODES_IMAGE':
|
||||||
return payloadType === 'IMAGE_DTO';
|
return payloadType === 'IMAGE_DTO';
|
||||||
case 'SET_MULTI_NODES_IMAGE':
|
case 'SET_MULTI_NODES_IMAGE':
|
||||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
|
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
case 'ADD_TO_BATCH':
|
case 'ADD_TO_BATCH':
|
||||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
|
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
case 'MOVE_BOARD': {
|
case 'ADD_TO_BOARD': {
|
||||||
// If the board is the same, don't allow the drop
|
// If the board is the same, don't allow the drop
|
||||||
|
|
||||||
// Check the payload types
|
// Check the payload types
|
||||||
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
|
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
|
if (!isPayloadValid) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the image's board is the board we are dragging onto
|
||||||
|
if (payloadType === 'IMAGE_DTO') {
|
||||||
|
const { imageDTO } = active.data.current.payload;
|
||||||
|
const currentBoard = imageDTO.board_id ?? 'none';
|
||||||
|
const destinationBoard = overData.context.boardId;
|
||||||
|
|
||||||
|
return currentBoard !== destinationBoard;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (payloadType === 'IMAGE_DTOS') {
|
||||||
|
// TODO (multi-select)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
case 'REMOVE_FROM_BOARD': {
|
||||||
|
// If the board is the same, don't allow the drop
|
||||||
|
|
||||||
|
// Check the payload types
|
||||||
|
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
if (!isPayloadValid) {
|
if (!isPayloadValid) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -172,20 +199,16 @@ export const isValidDrop = (
|
|||||||
if (payloadType === 'IMAGE_DTO') {
|
if (payloadType === 'IMAGE_DTO') {
|
||||||
const { imageDTO } = active.data.current.payload;
|
const { imageDTO } = active.data.current.payload;
|
||||||
const currentBoard = imageDTO.board_id;
|
const currentBoard = imageDTO.board_id;
|
||||||
const destinationBoard = overData.context.boardId;
|
|
||||||
|
|
||||||
const isSameBoard = currentBoard === destinationBoard;
|
return currentBoard !== 'none';
|
||||||
const isDestinationValid = !currentBoard ? destinationBoard : true;
|
|
||||||
|
|
||||||
return !isSameBoard && isDestinationValid;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (payloadType === 'IMAGE_NAMES') {
|
if (payloadType === 'IMAGE_DTOS') {
|
||||||
// TODO (multi-select)
|
// TODO (multi-select)
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return false;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
|
import { Middleware } from '@reduxjs/toolkit';
|
||||||
import { store } from 'app/store/store';
|
import { store } from 'app/store/store';
|
||||||
|
import { PartialAppConfig } from 'app/types/invokeai';
|
||||||
import React, {
|
import React, {
|
||||||
lazy,
|
lazy,
|
||||||
memo,
|
memo,
|
||||||
@@ -7,16 +9,11 @@ import React, {
|
|||||||
useEffect,
|
useEffect,
|
||||||
} from 'react';
|
} from 'react';
|
||||||
import { Provider } from 'react-redux';
|
import { Provider } from 'react-redux';
|
||||||
|
|
||||||
import { PartialAppConfig } from 'app/types/invokeai';
|
|
||||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||||
import Loading from '../../common/components/Loading/Loading';
|
import { $authToken, $baseUrl, $projectId } from 'services/api/client';
|
||||||
|
|
||||||
import { Middleware } from '@reduxjs/toolkit';
|
|
||||||
import { $authToken, $baseUrl } from 'services/api/client';
|
|
||||||
import { socketMiddleware } from 'services/events/middleware';
|
import { socketMiddleware } from 'services/events/middleware';
|
||||||
|
import Loading from '../../common/components/Loading/Loading';
|
||||||
import '../../i18n';
|
import '../../i18n';
|
||||||
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
|
|
||||||
import ImageDndContext from './ImageDnd/ImageDndContext';
|
import ImageDndContext from './ImageDnd/ImageDndContext';
|
||||||
|
|
||||||
const App = lazy(() => import('./App'));
|
const App = lazy(() => import('./App'));
|
||||||
@@ -37,6 +34,7 @@ const InvokeAIUI = ({
|
|||||||
config,
|
config,
|
||||||
headerComponent,
|
headerComponent,
|
||||||
middleware,
|
middleware,
|
||||||
|
projectId,
|
||||||
}: Props) => {
|
}: Props) => {
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// configure API client token
|
// configure API client token
|
||||||
@@ -49,6 +47,11 @@ const InvokeAIUI = ({
|
|||||||
$baseUrl.set(apiUrl);
|
$baseUrl.set(apiUrl);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// configure API client project header
|
||||||
|
if (projectId) {
|
||||||
|
$projectId.set(projectId);
|
||||||
|
}
|
||||||
|
|
||||||
// reset dynamically added middlewares
|
// reset dynamically added middlewares
|
||||||
resetMiddlewares();
|
resetMiddlewares();
|
||||||
|
|
||||||
@@ -68,8 +71,9 @@ const InvokeAIUI = ({
|
|||||||
// Reset the API client token and base url on unmount
|
// Reset the API client token and base url on unmount
|
||||||
$baseUrl.set(undefined);
|
$baseUrl.set(undefined);
|
||||||
$authToken.set(undefined);
|
$authToken.set(undefined);
|
||||||
|
$projectId.set(undefined);
|
||||||
};
|
};
|
||||||
}, [apiUrl, token, middleware]);
|
}, [apiUrl, token, middleware, projectId]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<React.StrictMode>
|
<React.StrictMode>
|
||||||
@@ -77,9 +81,7 @@ const InvokeAIUI = ({
|
|||||||
<React.Suspense fallback={<Loading />}>
|
<React.Suspense fallback={<Loading />}>
|
||||||
<ThemeLocaleProvider>
|
<ThemeLocaleProvider>
|
||||||
<ImageDndContext>
|
<ImageDndContext>
|
||||||
<AddImageToBoardContextProvider>
|
<App config={config} headerComponent={headerComponent} />
|
||||||
<App config={config} headerComponent={headerComponent} />
|
|
||||||
</AddImageToBoardContextProvider>
|
|
||||||
</ImageDndContext>
|
</ImageDndContext>
|
||||||
</ThemeLocaleProvider>
|
</ThemeLocaleProvider>
|
||||||
</React.Suspense>
|
</React.Suspense>
|
||||||
|
|||||||
@@ -1,91 +0,0 @@
|
|||||||
import { useDisclosure } from '@chakra-ui/react';
|
|
||||||
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
|
|
||||||
import { ImageDTO } from 'services/api/types';
|
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
|
||||||
import { useAppDispatch } from '../store/storeHooks';
|
|
||||||
|
|
||||||
export type ImageUsage = {
|
|
||||||
isInitialImage: boolean;
|
|
||||||
isCanvasImage: boolean;
|
|
||||||
isNodesImage: boolean;
|
|
||||||
isControlNetImage: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
type AddImageToBoardContextValue = {
|
|
||||||
/**
|
|
||||||
* Whether the move image dialog is open.
|
|
||||||
*/
|
|
||||||
isOpen: boolean;
|
|
||||||
/**
|
|
||||||
* Closes the move image dialog.
|
|
||||||
*/
|
|
||||||
onClose: () => void;
|
|
||||||
/**
|
|
||||||
* The image pending movement
|
|
||||||
*/
|
|
||||||
image?: ImageDTO;
|
|
||||||
onClickAddToBoard: (image: ImageDTO) => void;
|
|
||||||
handleAddToBoard: (boardId: string) => void;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const AddImageToBoardContext =
|
|
||||||
createContext<AddImageToBoardContextValue>({
|
|
||||||
isOpen: false,
|
|
||||||
onClose: () => undefined,
|
|
||||||
onClickAddToBoard: () => undefined,
|
|
||||||
handleAddToBoard: () => undefined,
|
|
||||||
});
|
|
||||||
|
|
||||||
type Props = PropsWithChildren;
|
|
||||||
|
|
||||||
export const AddImageToBoardContextProvider = (props: Props) => {
|
|
||||||
const [imageToMove, setImageToMove] = useState<ImageDTO>();
|
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
// Clean up after deleting or dismissing the modal
|
|
||||||
const closeAndClearImageToDelete = useCallback(() => {
|
|
||||||
setImageToMove(undefined);
|
|
||||||
onClose();
|
|
||||||
}, [onClose]);
|
|
||||||
|
|
||||||
const onClickAddToBoard = useCallback(
|
|
||||||
(image?: ImageDTO) => {
|
|
||||||
if (!image) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setImageToMove(image);
|
|
||||||
onOpen();
|
|
||||||
},
|
|
||||||
[setImageToMove, onOpen]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleAddToBoard = useCallback(
|
|
||||||
(boardId: string) => {
|
|
||||||
if (imageToMove) {
|
|
||||||
dispatch(
|
|
||||||
imagesApi.endpoints.addImageToBoard.initiate({
|
|
||||||
imageDTO: imageToMove,
|
|
||||||
board_id: boardId,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
closeAndClearImageToDelete();
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[dispatch, closeAndClearImageToDelete, imageToMove]
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<AddImageToBoardContext.Provider
|
|
||||||
value={{
|
|
||||||
isOpen,
|
|
||||||
image: imageToMove,
|
|
||||||
onClose: closeAndClearImageToDelete,
|
|
||||||
onClickAddToBoard,
|
|
||||||
handleAddToBoard,
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{props.children}
|
|
||||||
</AddImageToBoardContext.Provider>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
import { createContext } from 'react';
|
|
||||||
|
|
||||||
type VoidFunc = () => void;
|
|
||||||
|
|
||||||
type ImageUploaderTriggerContextType = VoidFunc | null;
|
|
||||||
|
|
||||||
export const ImageUploaderTriggerContext =
|
|
||||||
createContext<ImageUploaderTriggerContextType>(null);
|
|
||||||
@@ -23,6 +23,6 @@ const serializationDenylist: {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const serialize: SerializeFunction = (data, key) => {
|
export const serialize: SerializeFunction = (data, key) => {
|
||||||
const result = omit(data, serializationDenylist[key]);
|
const result = omit(data, serializationDenylist[key] ?? []);
|
||||||
return JSON.stringify(result);
|
return JSON.stringify(result);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -27,7 +27,8 @@ import {
|
|||||||
addImageDeletedFulfilledListener,
|
addImageDeletedFulfilledListener,
|
||||||
addImageDeletedPendingListener,
|
addImageDeletedPendingListener,
|
||||||
addImageDeletedRejectedListener,
|
addImageDeletedRejectedListener,
|
||||||
addRequestedImageDeletionListener,
|
addRequestedSingleImageDeletionListener,
|
||||||
|
addRequestedMultipleImageDeletionListener,
|
||||||
} from './listeners/imageDeleted';
|
} from './listeners/imageDeleted';
|
||||||
import { addImageDroppedListener } from './listeners/imageDropped';
|
import { addImageDroppedListener } from './listeners/imageDropped';
|
||||||
import {
|
import {
|
||||||
@@ -111,7 +112,8 @@ addImageUploadedRejectedListener();
|
|||||||
addInitialImageSelectedListener();
|
addInitialImageSelectedListener();
|
||||||
|
|
||||||
// Image deleted
|
// Image deleted
|
||||||
addRequestedImageDeletionListener();
|
addRequestedSingleImageDeletionListener();
|
||||||
|
addRequestedMultipleImageDeletionListener();
|
||||||
addImageDeletedPendingListener();
|
addImageDeletedPendingListener();
|
||||||
addImageDeletedFulfilledListener();
|
addImageDeletedFulfilledListener();
|
||||||
addImageDeletedRejectedListener();
|
addImageDeletedRejectedListener();
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||||
import {
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
ImageCache,
|
|
||||||
getListImagesUrl,
|
|
||||||
imagesApi,
|
|
||||||
} from 'services/api/endpoints/images';
|
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
import { getListImagesUrl, imagesAdapter } from 'services/api/util';
|
||||||
|
import { ImageCache } from 'services/api/types';
|
||||||
|
|
||||||
export const appStarted = createAction('app/appStarted');
|
export const appStarted = createAction('app/appStarted');
|
||||||
|
|
||||||
@@ -34,7 +32,8 @@ export const addFirstListImagesListener = () => {
|
|||||||
|
|
||||||
if (data.ids.length > 0) {
|
if (data.ids.length > 0) {
|
||||||
// Select the first image
|
// Select the first image
|
||||||
dispatch(imageSelected(data.ids[0] as string));
|
const firstImage = imagesAdapter.getSelectors().selectAll(data)[0];
|
||||||
|
dispatch(imageSelected(firstImage ?? null));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -18,7 +18,9 @@ export const addAppConfigReceivedListener = () => {
|
|||||||
const infillMethod = getState().generation.infillMethod;
|
const infillMethod = getState().generation.infillMethod;
|
||||||
|
|
||||||
if (!infill_methods.includes(infillMethod)) {
|
if (!infill_methods.includes(infillMethod)) {
|
||||||
dispatch(setInfillMethod(infill_methods[0]));
|
// if there is no infill method, set it to the first one
|
||||||
|
// if there is no first one... god help us
|
||||||
|
dispatch(setInfillMethod(infill_methods[0] as string));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!nsfw_methods.includes('nsfw_checker')) {
|
if (!nsfw_methods.includes('nsfw_checker')) {
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { getImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
|
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
|
||||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||||
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { boardsApi } from '../../../../../services/api/endpoints/boards';
|
|
||||||
|
|
||||||
export const addDeleteBoardAndImagesFulfilledListener = () => {
|
export const addDeleteBoardAndImagesFulfilledListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
matcher: boardsApi.endpoints.deleteBoardAndImages.matchFulfilled,
|
matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
|
||||||
effect: async (action, { dispatch, getState }) => {
|
effect: async (action, { dispatch, getState }) => {
|
||||||
const { deleted_images } = action.payload;
|
const { deleted_images } = action.payload;
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import {
|
|||||||
} from 'features/gallery/store/types';
|
} from 'features/gallery/store/types';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
import { imagesSelectors } from 'services/api/util';
|
||||||
|
|
||||||
export const addBoardIdSelectedListener = () => {
|
export const addBoardIdSelectedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@@ -52,8 +53,9 @@ export const addBoardIdSelectedListener = () => {
|
|||||||
queryArgs
|
queryArgs
|
||||||
)(getState());
|
)(getState());
|
||||||
|
|
||||||
if (boardImagesData?.ids.length) {
|
if (boardImagesData) {
|
||||||
dispatch(imageSelected((boardImagesData.ids[0] as string) ?? null));
|
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
|
||||||
|
dispatch(imageSelected(firstImage ?? null));
|
||||||
} else {
|
} else {
|
||||||
// board has no images - deselect
|
// board has no images - deselect
|
||||||
dispatch(imageSelected(null));
|
dispatch(imageSelected(null));
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ export const addCanvasSavedToGalleryListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const { autoAddBoardId } = state.gallery;
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.endpoints.uploadImage.initiate({
|
imagesApi.endpoints.uploadImage.initiate({
|
||||||
file: new File([blob], 'savedCanvas.png', {
|
file: new File([blob], 'savedCanvas.png', {
|
||||||
@@ -33,7 +35,7 @@ export const addCanvasSavedToGalleryListener = () => {
|
|||||||
}),
|
}),
|
||||||
image_category: 'general',
|
image_category: 'general',
|
||||||
is_intermediate: false,
|
is_intermediate: false,
|
||||||
board_id: state.gallery.autoAddBoardId,
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
crop_visible: true,
|
crop_visible: true,
|
||||||
postUploadAction: {
|
postUploadAction: {
|
||||||
type: 'TOAST',
|
type: 'TOAST',
|
||||||
|
|||||||
@@ -31,15 +31,20 @@ const predicate: AnyListenerPredicate<RootState> = (
|
|||||||
// do not process if the user just disabled auto-config
|
// do not process if the user just disabled auto-config
|
||||||
if (
|
if (
|
||||||
prevState.controlNet.controlNets[action.payload.controlNetId]
|
prevState.controlNet.controlNets[action.payload.controlNetId]
|
||||||
.shouldAutoConfig === true
|
?.shouldAutoConfig === true
|
||||||
) {
|
) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const { controlImage, processorType, shouldAutoConfig } =
|
const cn = state.controlNet.controlNets[action.payload.controlNetId];
|
||||||
state.controlNet.controlNets[action.payload.controlNetId];
|
|
||||||
|
|
||||||
|
if (!cn) {
|
||||||
|
// something is wrong, the controlNet should exist
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { controlImage, processorType, shouldAutoConfig } = cn;
|
||||||
if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
|
if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
|
||||||
// do not process if the action is a model change but the processor settings are dirty
|
// do not process if the action is a model change but the processor settings are dirty
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ export const addControlNetImageProcessedListener = () => {
|
|||||||
const { controlNetId } = action.payload;
|
const { controlNetId } = action.payload;
|
||||||
const controlNet = getState().controlNet.controlNets[controlNetId];
|
const controlNet = getState().controlNet.controlNets[controlNetId];
|
||||||
|
|
||||||
if (!controlNet.controlImage) {
|
if (!controlNet?.controlImage) {
|
||||||
log.error('Unable to process ControlNet image');
|
log.error('Unable to process ControlNet image');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,57 +1,72 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||||
|
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
|
||||||
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
|
|
||||||
import { isModalOpenChanged } from 'features/imageDeletion/store/imageDeletionSlice';
|
|
||||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||||
import { clamp } from 'lodash-es';
|
import { clamp } from 'lodash-es';
|
||||||
import { api } from 'services/api';
|
import { api } from 'services/api';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import { imagesAdapter } from 'services/api/util';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
/**
|
export const addRequestedSingleImageDeletionListener = () => {
|
||||||
* Called when the user requests an image deletion
|
|
||||||
*/
|
|
||||||
export const addRequestedImageDeletionListener = () => {
|
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: imageDeletionConfirmed,
|
actionCreator: imageDeletionConfirmed,
|
||||||
effect: async (action, { dispatch, getState, condition }) => {
|
effect: async (action, { dispatch, getState, condition }) => {
|
||||||
const { imageDTO, imageUsage } = action.payload;
|
const { imageDTOs, imagesUsage } = action.payload;
|
||||||
|
|
||||||
|
if (imageDTOs.length !== 1 || imagesUsage.length !== 1) {
|
||||||
|
// handle multiples in separate listener
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const imageDTO = imageDTOs[0];
|
||||||
|
const imageUsage = imagesUsage[0];
|
||||||
|
|
||||||
|
if (!imageDTO || !imageUsage) {
|
||||||
|
// satisfy noUncheckedIndexedAccess
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(isModalOpenChanged(false));
|
dispatch(isModalOpenChanged(false));
|
||||||
|
|
||||||
const { image_name } = imageDTO;
|
|
||||||
|
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const lastSelectedImage =
|
const lastSelectedImage =
|
||||||
state.gallery.selection[state.gallery.selection.length - 1];
|
state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
|
||||||
|
|
||||||
|
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
|
||||||
|
const { image_name } = imageDTO;
|
||||||
|
|
||||||
if (lastSelectedImage === image_name) {
|
|
||||||
const baseQueryArgs = selectListImagesBaseQueryArgs(state);
|
const baseQueryArgs = selectListImagesBaseQueryArgs(state);
|
||||||
const { data } =
|
const { data } =
|
||||||
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||||
|
|
||||||
const ids = data?.ids ?? [];
|
const cachedImageDTOs = data
|
||||||
|
? imagesAdapter.getSelectors().selectAll(data)
|
||||||
|
: [];
|
||||||
|
|
||||||
const deletedImageIndex = ids.findIndex(
|
const deletedImageIndex = cachedImageDTOs.findIndex(
|
||||||
(result) => result.toString() === image_name
|
(i) => i.image_name === image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
const filteredIds = ids.filter((id) => id.toString() !== image_name);
|
const filteredImageDTOs = cachedImageDTOs.filter(
|
||||||
|
(i) => i.image_name !== image_name
|
||||||
|
);
|
||||||
|
|
||||||
const newSelectedImageIndex = clamp(
|
const newSelectedImageIndex = clamp(
|
||||||
deletedImageIndex,
|
deletedImageIndex,
|
||||||
0,
|
0,
|
||||||
filteredIds.length - 1
|
filteredImageDTOs.length - 1
|
||||||
);
|
);
|
||||||
|
|
||||||
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];
|
||||||
|
|
||||||
if (newSelectedImageId) {
|
if (newSelectedImageDTO) {
|
||||||
dispatch(imageSelected(newSelectedImageId as string));
|
dispatch(imageSelected(newSelectedImageDTO));
|
||||||
} else {
|
} else {
|
||||||
dispatch(imageSelected(null));
|
dispatch(imageSelected(null));
|
||||||
}
|
}
|
||||||
@@ -97,6 +112,66 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when the user requests an image deletion
|
||||||
|
*/
|
||||||
|
export const addRequestedMultipleImageDeletionListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageDeletionConfirmed,
|
||||||
|
effect: async (action, { dispatch, getState }) => {
|
||||||
|
const { imageDTOs, imagesUsage } = action.payload;
|
||||||
|
|
||||||
|
if (imageDTOs.length < 1 || imagesUsage.length < 1) {
|
||||||
|
// handle singles in separate listener
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Delete from server
|
||||||
|
await dispatch(
|
||||||
|
imagesApi.endpoints.deleteImages.initiate({ imageDTOs })
|
||||||
|
).unwrap();
|
||||||
|
const state = getState();
|
||||||
|
const baseQueryArgs = selectListImagesBaseQueryArgs(state);
|
||||||
|
const { data } =
|
||||||
|
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||||
|
|
||||||
|
const newSelectedImageDTO = data
|
||||||
|
? imagesAdapter.getSelectors().selectAll(data)[0]
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
if (newSelectedImageDTO) {
|
||||||
|
dispatch(imageSelected(newSelectedImageDTO));
|
||||||
|
} else {
|
||||||
|
dispatch(imageSelected(null));
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(isModalOpenChanged(false));
|
||||||
|
|
||||||
|
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
|
||||||
|
|
||||||
|
if (imagesUsage.some((i) => i.isCanvasImage)) {
|
||||||
|
dispatch(resetCanvas());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (imagesUsage.some((i) => i.isControlNetImage)) {
|
||||||
|
dispatch(controlNetReset());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (imagesUsage.some((i) => i.isInitialImage)) {
|
||||||
|
dispatch(clearInitialImage());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (imagesUsage.some((i) => i.isNodesImage)) {
|
||||||
|
dispatch(nodeEditorReset());
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Called when the actual delete request is sent to the server
|
* Called when the actual delete request is sent to the server
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -6,10 +6,7 @@ import {
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import {
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
imageSelected,
|
|
||||||
imagesAddedToBatch,
|
|
||||||
} from 'features/gallery/store/gallerySlice';
|
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
@@ -27,19 +24,32 @@ export const addImageDroppedListener = () => {
|
|||||||
const log = logger('images');
|
const log = logger('images');
|
||||||
const { activeData, overData } = action.payload;
|
const { activeData, overData } = action.payload;
|
||||||
|
|
||||||
log.debug({ activeData, overData }, 'Image or selection dropped');
|
if (activeData.payloadType === 'IMAGE_DTO') {
|
||||||
|
log.debug({ activeData, overData }, 'Image dropped');
|
||||||
|
} else if (activeData.payloadType === 'IMAGE_DTOS') {
|
||||||
|
log.debug(
|
||||||
|
{ activeData, overData },
|
||||||
|
`Images (${activeData.payload.imageDTOs.length}) dropped`
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
log.debug({ activeData, overData }, `Unknown payload dropped`);
|
||||||
|
}
|
||||||
|
|
||||||
// set current image
|
/**
|
||||||
|
* Image dropped on current image
|
||||||
|
*/
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'SET_CURRENT_IMAGE' &&
|
overData.actionType === 'SET_CURRENT_IMAGE' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
activeData.payload.imageDTO
|
activeData.payload.imageDTO
|
||||||
) {
|
) {
|
||||||
dispatch(imageSelected(activeData.payload.imageDTO.image_name));
|
dispatch(imageSelected(activeData.payload.imageDTO));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// set initial image
|
/**
|
||||||
|
* Image dropped on initial image
|
||||||
|
*/
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'SET_INITIAL_IMAGE' &&
|
overData.actionType === 'SET_INITIAL_IMAGE' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
@@ -49,27 +59,9 @@ export const addImageDroppedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// add image to batch
|
/**
|
||||||
if (
|
* Image dropped on ControlNet
|
||||||
overData.actionType === 'ADD_TO_BATCH' &&
|
*/
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
|
||||||
activeData.payload.imageDTO
|
|
||||||
) {
|
|
||||||
dispatch(imagesAddedToBatch([activeData.payload.imageDTO.image_name]));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// add multiple images to batch
|
|
||||||
if (
|
|
||||||
overData.actionType === 'ADD_TO_BATCH' &&
|
|
||||||
activeData.payloadType === 'IMAGE_NAMES'
|
|
||||||
) {
|
|
||||||
dispatch(imagesAddedToBatch(activeData.payload.image_names));
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// set control image
|
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'SET_CONTROLNET_IMAGE' &&
|
overData.actionType === 'SET_CONTROLNET_IMAGE' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
@@ -85,7 +77,9 @@ export const addImageDroppedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// set canvas image
|
/**
|
||||||
|
* Image dropped on Canvas
|
||||||
|
*/
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
|
overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
@@ -95,7 +89,9 @@ export const addImageDroppedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// set nodes image
|
/**
|
||||||
|
* Image dropped on node image field
|
||||||
|
*/
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'SET_NODES_IMAGE' &&
|
overData.actionType === 'SET_NODES_IMAGE' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
@@ -112,61 +108,36 @@ export const addImageDroppedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// set multiple nodes images (single image handler)
|
/**
|
||||||
if (
|
* TODO
|
||||||
overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
|
* Image selection dropped on node image collection field
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
*/
|
||||||
activeData.payload.imageDTO
|
|
||||||
) {
|
|
||||||
const { fieldName, nodeId } = overData.context;
|
|
||||||
dispatch(
|
|
||||||
fieldValueChanged({
|
|
||||||
nodeId,
|
|
||||||
fieldName,
|
|
||||||
value: [activeData.payload.imageDTO],
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// // set multiple nodes images (multiple images handler)
|
|
||||||
// if (
|
// if (
|
||||||
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
|
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
|
||||||
// activeData.payloadType === 'IMAGE_NAMES'
|
// activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
|
// activeData.payload.imageDTO
|
||||||
// ) {
|
// ) {
|
||||||
// const { fieldName, nodeId } = overData.context;
|
// const { fieldName, nodeId } = overData.context;
|
||||||
// dispatch(
|
// dispatch(
|
||||||
// imageCollectionFieldValueChanged({
|
// fieldValueChanged({
|
||||||
// nodeId,
|
// nodeId,
|
||||||
// fieldName,
|
// fieldName,
|
||||||
// value: activeData.payload.image_names.map((image_name) => ({
|
// value: [activeData.payload.imageDTO],
|
||||||
// image_name,
|
|
||||||
// })),
|
|
||||||
// })
|
// })
|
||||||
// );
|
// );
|
||||||
// return;
|
// return;
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// add image to board
|
/**
|
||||||
|
* Image dropped on user board
|
||||||
|
*/
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'MOVE_BOARD' &&
|
overData.actionType === 'ADD_TO_BOARD' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
activeData.payload.imageDTO
|
activeData.payload.imageDTO
|
||||||
) {
|
) {
|
||||||
const { imageDTO } = activeData.payload;
|
const { imageDTO } = activeData.payload;
|
||||||
const { boardId } = overData.context;
|
const { boardId } = overData.context;
|
||||||
|
|
||||||
// image was droppe on the "NoBoardBoard"
|
|
||||||
if (!boardId) {
|
|
||||||
dispatch(
|
|
||||||
imagesApi.endpoints.removeImageFromBoard.initiate({
|
|
||||||
imageDTO,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// image was dropped on a user board
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.endpoints.addImageToBoard.initiate({
|
imagesApi.endpoints.addImageToBoard.initiate({
|
||||||
imageDTO,
|
imageDTO,
|
||||||
@@ -176,67 +147,58 @@ export const addImageDroppedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// // add gallery selection to board
|
/**
|
||||||
// if (
|
* Image dropped on 'none' board
|
||||||
// overData.actionType === 'MOVE_BOARD' &&
|
*/
|
||||||
// activeData.payloadType === 'IMAGE_NAMES' &&
|
if (
|
||||||
// overData.context.boardId
|
overData.actionType === 'REMOVE_FROM_BOARD' &&
|
||||||
// ) {
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
// console.log('adding gallery selection to board');
|
activeData.payload.imageDTO
|
||||||
// const board_id = overData.context.boardId;
|
) {
|
||||||
// dispatch(
|
const { imageDTO } = activeData.payload;
|
||||||
// boardImagesApi.endpoints.addManyBoardImages.initiate({
|
dispatch(
|
||||||
// board_id,
|
imagesApi.endpoints.removeImageFromBoard.initiate({
|
||||||
// image_names: activeData.payload.image_names,
|
imageDTO,
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// return;
|
return;
|
||||||
// }
|
}
|
||||||
|
|
||||||
// // remove gallery selection from board
|
/**
|
||||||
// if (
|
* Multiple images dropped on user board
|
||||||
// overData.actionType === 'MOVE_BOARD' &&
|
*/
|
||||||
// activeData.payloadType === 'IMAGE_NAMES' &&
|
if (
|
||||||
// overData.context.boardId === null
|
overData.actionType === 'ADD_TO_BOARD' &&
|
||||||
// ) {
|
activeData.payloadType === 'IMAGE_DTOS' &&
|
||||||
// console.log('removing gallery selection to board');
|
activeData.payload.imageDTOs
|
||||||
// dispatch(
|
) {
|
||||||
// boardImagesApi.endpoints.deleteManyBoardImages.initiate({
|
const { imageDTOs } = activeData.payload;
|
||||||
// image_names: activeData.payload.image_names,
|
const { boardId } = overData.context;
|
||||||
// })
|
dispatch(
|
||||||
// );
|
imagesApi.endpoints.addImagesToBoard.initiate({
|
||||||
// return;
|
imageDTOs,
|
||||||
// }
|
board_id: boardId,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// // add batch selection to board
|
/**
|
||||||
// if (
|
* Multiple images dropped on 'none' board
|
||||||
// overData.actionType === 'MOVE_BOARD' &&
|
*/
|
||||||
// activeData.payloadType === 'IMAGE_NAMES' &&
|
if (
|
||||||
// overData.context.boardId
|
overData.actionType === 'REMOVE_FROM_BOARD' &&
|
||||||
// ) {
|
activeData.payloadType === 'IMAGE_DTOS' &&
|
||||||
// const board_id = overData.context.boardId;
|
activeData.payload.imageDTOs
|
||||||
// dispatch(
|
) {
|
||||||
// boardImagesApi.endpoints.addManyBoardImages.initiate({
|
const { imageDTOs } = activeData.payload;
|
||||||
// board_id,
|
dispatch(
|
||||||
// image_names: activeData.payload.image_names,
|
imagesApi.endpoints.removeImagesFromBoard.initiate({
|
||||||
// })
|
imageDTOs,
|
||||||
// );
|
})
|
||||||
// return;
|
);
|
||||||
// }
|
return;
|
||||||
|
}
|
||||||
// // remove batch selection from board
|
|
||||||
// if (
|
|
||||||
// overData.actionType === 'MOVE_BOARD' &&
|
|
||||||
// activeData.payloadType === 'IMAGE_NAMES' &&
|
|
||||||
// overData.context.boardId === null
|
|
||||||
// ) {
|
|
||||||
// dispatch(
|
|
||||||
// boardImagesApi.endpoints.deleteManyBoardImages.initiate({
|
|
||||||
// image_names: activeData.payload.image_names,
|
|
||||||
// })
|
|
||||||
// );
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,37 +1,32 @@
|
|||||||
import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
|
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||||
import { selectImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
|
import { selectImageUsage } from 'features/deleteImageModal/store/selectors';
|
||||||
import {
|
import {
|
||||||
imageToDeleteSelected,
|
imagesToDeleteSelected,
|
||||||
isModalOpenChanged,
|
isModalOpenChanged,
|
||||||
} from 'features/imageDeletion/store/imageDeletionSlice';
|
} from 'features/deleteImageModal/store/slice';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
export const addImageToDeleteSelectedListener = () => {
|
export const addImageToDeleteSelectedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: imageToDeleteSelected,
|
actionCreator: imagesToDeleteSelected,
|
||||||
effect: async (action, { dispatch, getState }) => {
|
effect: async (action, { dispatch, getState }) => {
|
||||||
const imageDTO = action.payload;
|
const imageDTOs = action.payload;
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const { shouldConfirmOnDelete } = state.system;
|
const { shouldConfirmOnDelete } = state.system;
|
||||||
const imageUsage = selectImageUsage(getState());
|
const imagesUsage = selectImageUsage(getState());
|
||||||
|
|
||||||
if (!imageUsage) {
|
|
||||||
// should never happen
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const isImageInUse =
|
const isImageInUse =
|
||||||
imageUsage.isCanvasImage ||
|
imagesUsage.some((i) => i.isCanvasImage) ||
|
||||||
imageUsage.isInitialImage ||
|
imagesUsage.some((i) => i.isInitialImage) ||
|
||||||
imageUsage.isControlNetImage ||
|
imagesUsage.some((i) => i.isControlNetImage) ||
|
||||||
imageUsage.isNodesImage;
|
imagesUsage.some((i) => i.isNodesImage);
|
||||||
|
|
||||||
if (shouldConfirmOnDelete || isImageInUse) {
|
if (shouldConfirmOnDelete || isImageInUse) {
|
||||||
dispatch(isModalOpenChanged(true));
|
dispatch(isModalOpenChanged(true));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(imageDeletionConfirmed({ imageDTO, imageUsage }));
|
dispatch(imageDeletionConfirmed({ imageDTOs, imagesUsage }));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -2,14 +2,13 @@ import { UseToastOptions } from '@chakra-ui/react';
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { imagesAddedToBatch } from 'features/gallery/store/gallerySlice';
|
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { omit } from 'lodash-es';
|
||||||
import { boardsApi } from 'services/api/endpoints/boards';
|
import { boardsApi } from 'services/api/endpoints/boards';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { imagesApi } from '../../../../../services/api/endpoints/images';
|
import { imagesApi } from '../../../../../services/api/endpoints/images';
|
||||||
import { omit } from 'lodash-es';
|
|
||||||
|
|
||||||
const DEFAULT_UPLOADED_TOAST: UseToastOptions = {
|
const DEFAULT_UPLOADED_TOAST: UseToastOptions = {
|
||||||
title: 'Image Uploaded',
|
title: 'Image Uploaded',
|
||||||
@@ -41,7 +40,7 @@ export const addImageUploadedFulfilledListener = () => {
|
|||||||
// default action - just upload and alert user
|
// default action - just upload and alert user
|
||||||
if (postUploadAction?.type === 'TOAST') {
|
if (postUploadAction?.type === 'TOAST') {
|
||||||
const { toastOptions } = postUploadAction;
|
const { toastOptions } = postUploadAction;
|
||||||
if (!autoAddBoardId) {
|
if (!autoAddBoardId || autoAddBoardId === 'none') {
|
||||||
dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions }));
|
dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions }));
|
||||||
} else {
|
} else {
|
||||||
// Add this image to the board
|
// Add this image to the board
|
||||||
@@ -121,17 +120,6 @@ export const addImageUploadedFulfilledListener = () => {
|
|||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (postUploadAction?.type === 'ADD_TO_BATCH') {
|
|
||||||
dispatch(imagesAddedToBatch([imageDTO.image_name]));
|
|
||||||
dispatch(
|
|
||||||
addToast({
|
|
||||||
...DEFAULT_UPLOADED_TOAST,
|
|
||||||
description: 'Added to batch',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import {
|
|||||||
setShouldUseSDXLRefiner,
|
setShouldUseSDXLRefiner,
|
||||||
} from 'features/sdxl/store/sdxlSlice';
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
import { forEach, some } from 'lodash-es';
|
import { forEach, some } from 'lodash-es';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
export const addModelsLoadedListener = () => {
|
export const addModelsLoadedListener = () => {
|
||||||
@@ -144,8 +144,9 @@ export const addModelsLoadedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const firstModelId = action.payload.ids[0];
|
const firstModel = vaeModelsAdapter
|
||||||
const firstModel = action.payload.entities[firstModelId];
|
.getSelectors()
|
||||||
|
.selectAll(action.payload)[0];
|
||||||
|
|
||||||
if (!firstModel) {
|
if (!firstModel) {
|
||||||
// No custom VAEs loaded at all; use the default
|
// No custom VAEs loaded at all; use the default
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ import {
|
|||||||
} from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||||
import { progressImageSet } from 'features/system/store/systemSlice';
|
import { progressImageSet } from 'features/system/store/systemSlice';
|
||||||
import { imagesAdapter, imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { isImageOutput } from 'services/api/guards';
|
import { isImageOutput } from 'services/api/guards';
|
||||||
import { sessionCanceled } from 'services/api/thunks/session';
|
import { sessionCanceled } from 'services/api/thunks/session';
|
||||||
|
import { imagesAdapter } from 'services/api/util';
|
||||||
import {
|
import {
|
||||||
appSocketInvocationComplete,
|
appSocketInvocationComplete,
|
||||||
socketInvocationComplete,
|
socketInvocationComplete,
|
||||||
@@ -67,7 +68,7 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
const { autoAddBoardId } = gallery;
|
const { autoAddBoardId } = gallery;
|
||||||
if (autoAddBoardId) {
|
if (autoAddBoardId && autoAddBoardId !== 'none') {
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.endpoints.addImageToBoard.initiate({
|
imagesApi.endpoints.addImageToBoard.initiate({
|
||||||
board_id: autoAddBoardId,
|
board_id: autoAddBoardId,
|
||||||
@@ -83,10 +84,7 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
categories: IMAGE_CATEGORIES,
|
categories: IMAGE_CATEGORIES,
|
||||||
},
|
},
|
||||||
(draft) => {
|
(draft) => {
|
||||||
const oldTotal = draft.total;
|
imagesAdapter.addOne(draft, imageDTO);
|
||||||
const newState = imagesAdapter.addOne(draft, imageDTO);
|
|
||||||
const delta = newState.total - oldTotal;
|
|
||||||
draft.total = draft.total + delta;
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
@@ -94,8 +92,8 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.util.invalidateTags([
|
imagesApi.util.invalidateTags([
|
||||||
{ type: 'BoardImagesTotal', id: autoAddBoardId ?? 'none' },
|
{ type: 'BoardImagesTotal', id: autoAddBoardId },
|
||||||
{ type: 'BoardAssetsTotal', id: autoAddBoardId ?? 'none' },
|
{ type: 'BoardAssetsTotal', id: autoAddBoardId },
|
||||||
])
|
])
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -110,7 +108,7 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
} else if (!autoAddBoardId) {
|
} else if (!autoAddBoardId) {
|
||||||
dispatch(galleryViewChanged('images'));
|
dispatch(galleryViewChanged('images'));
|
||||||
}
|
}
|
||||||
dispatch(imageSelected(imageDTO.image_name));
|
dispatch(imageSelected(imageDTO));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ import {
|
|||||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||||
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||||
import boardsReducer from 'features/gallery/store/boardSlice';
|
|
||||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||||
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
|
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
|
||||||
|
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
|
||||||
import loraReducer from 'features/lora/store/loraSlice';
|
import loraReducer from 'features/lora/store/loraSlice';
|
||||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||||
import generationReducer from 'features/parameters/store/generationSlice';
|
import generationReducer from 'features/parameters/store/generationSlice';
|
||||||
@@ -43,9 +43,9 @@ const allReducers = {
|
|||||||
ui: uiReducer,
|
ui: uiReducer,
|
||||||
hotkeys: hotkeysReducer,
|
hotkeys: hotkeysReducer,
|
||||||
controlNet: controlNetReducer,
|
controlNet: controlNetReducer,
|
||||||
boards: boardsReducer,
|
|
||||||
dynamicPrompts: dynamicPromptsReducer,
|
dynamicPrompts: dynamicPromptsReducer,
|
||||||
imageDeletion: imageDeletionReducer,
|
deleteImageModal: deleteImageModalReducer,
|
||||||
|
changeBoardModal: changeBoardModalReducer,
|
||||||
lora: loraReducer,
|
lora: loraReducer,
|
||||||
modelmanager: modelmanagerReducer,
|
modelmanager: modelmanagerReducer,
|
||||||
sdxl: sdxlReducer,
|
sdxl: sdxlReducer,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { Flex, Text, useColorMode } from '@chakra-ui/react';
|
import { Box, Flex, useColorMode } from '@chakra-ui/react';
|
||||||
import { motion } from 'framer-motion';
|
import { motion } from 'framer-motion';
|
||||||
import { ReactNode, memo, useRef } from 'react';
|
import { ReactNode, memo, useRef } from 'react';
|
||||||
import { mode } from 'theme/util/mode';
|
import { mode } from 'theme/util/mode';
|
||||||
@@ -74,7 +74,7 @@ export const IAIDropOverlay = (props: Props) => {
|
|||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Text
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
fontSize: '2xl',
|
fontSize: '2xl',
|
||||||
fontWeight: 600,
|
fontWeight: 600,
|
||||||
@@ -87,7 +87,7 @@ export const IAIDropOverlay = (props: Props) => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{label}
|
{label}
|
||||||
</Text>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
</motion.div>
|
</motion.div>
|
||||||
|
|||||||
@@ -53,7 +53,9 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
|||||||
// wrap onChange to clear search value on select
|
// wrap onChange to clear search value on select
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
setSearchValue('');
|
// cannot figure out why we were doing this, but it was causing an issue where if you
|
||||||
|
// select the currently-selected item, it reset the search value to empty
|
||||||
|
// setSearchValue('');
|
||||||
|
|
||||||
if (!onChange) {
|
if (!onChange) {
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
image_category: 'user',
|
image_category: 'user',
|
||||||
is_intermediate: false,
|
is_intermediate: false,
|
||||||
postUploadAction,
|
postUploadAction,
|
||||||
board_id: autoAddBoardId,
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[autoAddBoardId, postUploadAction, uploadImage]
|
[autoAddBoardId, postUploadAction, uploadImage]
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ export const useImageUploadButton = ({
|
|||||||
image_category: 'user',
|
image_category: 'user',
|
||||||
is_intermediate: false,
|
is_intermediate: false,
|
||||||
postUploadAction: postUploadAction ?? { type: 'TOAST' },
|
postUploadAction: postUploadAction ?? { type: 'TOAST' },
|
||||||
board_id: autoAddBoardId,
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[autoAddBoardId, postUploadAction, uploadImage]
|
[autoAddBoardId, postUploadAction, uploadImage]
|
||||||
|
|||||||
@@ -33,6 +33,10 @@ const useColorPicker = () => {
|
|||||||
1
|
1
|
||||||
).data;
|
).data;
|
||||||
|
|
||||||
|
if (!(a && r && g && b)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(setColorPickerColor({ r, g, b, a }));
|
dispatch(setColorPickerColor({ r, g, b, a }));
|
||||||
},
|
},
|
||||||
commitColorUnderCursor: () => {
|
commitColorUnderCursor: () => {
|
||||||
|
|||||||
@@ -727,10 +727,13 @@ export const canvasSlice = createSlice({
|
|||||||
state.pastLayerStates.shift();
|
state.pastLayerStates.shift();
|
||||||
}
|
}
|
||||||
|
|
||||||
state.layerState.objects.push({
|
const imageToCommit = images[selectedImageIndex];
|
||||||
...images[selectedImageIndex],
|
|
||||||
});
|
|
||||||
|
|
||||||
|
if (imageToCommit) {
|
||||||
|
state.layerState.objects.push({
|
||||||
|
...imageToCommit,
|
||||||
|
});
|
||||||
|
}
|
||||||
state.layerState.stagingArea = {
|
state.layerState.stagingArea = {
|
||||||
...initialLayerState.stagingArea,
|
...initialLayerState.stagingArea,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -0,0 +1,132 @@
|
|||||||
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogBody,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogOverlay,
|
||||||
|
Flex,
|
||||||
|
Text,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
|
import { memo, useCallback, useMemo, useRef, useState } from 'react';
|
||||||
|
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||||
|
import {
|
||||||
|
useAddImagesToBoardMutation,
|
||||||
|
useRemoveImagesFromBoardMutation,
|
||||||
|
} from 'services/api/endpoints/images';
|
||||||
|
import { changeBoardReset, isModalOpenChanged } from '../store/slice';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ changeBoardModal }) => {
|
||||||
|
const { isModalOpen, imagesToChange } = changeBoardModal;
|
||||||
|
|
||||||
|
return {
|
||||||
|
isModalOpen,
|
||||||
|
imagesToChange,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ChangeBoardModal = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const [selectedBoard, setSelectedBoard] = useState<string | null>();
|
||||||
|
const { data: boards, isFetching } = useListAllBoardsQuery();
|
||||||
|
const { imagesToChange, isModalOpen } = useAppSelector(selector);
|
||||||
|
const [addImagesToBoard] = useAddImagesToBoardMutation();
|
||||||
|
const [removeImagesFromBoard] = useRemoveImagesFromBoardMutation();
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
const data: { label: string; value: string }[] = [
|
||||||
|
{ label: 'Uncategorized', value: 'none' },
|
||||||
|
];
|
||||||
|
(boards ?? []).forEach((board) =>
|
||||||
|
data.push({
|
||||||
|
label: board.board_name,
|
||||||
|
value: board.board_id,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [boards]);
|
||||||
|
|
||||||
|
const handleClose = useCallback(() => {
|
||||||
|
dispatch(changeBoardReset());
|
||||||
|
dispatch(isModalOpenChanged(false));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
const handleChangeBoard = useCallback(() => {
|
||||||
|
if (!imagesToChange.length || !selectedBoard) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (selectedBoard === 'none') {
|
||||||
|
removeImagesFromBoard({ imageDTOs: imagesToChange });
|
||||||
|
} else {
|
||||||
|
addImagesToBoard({
|
||||||
|
imageDTOs: imagesToChange,
|
||||||
|
board_id: selectedBoard,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
setSelectedBoard(null);
|
||||||
|
dispatch(changeBoardReset());
|
||||||
|
}, [
|
||||||
|
addImagesToBoard,
|
||||||
|
dispatch,
|
||||||
|
imagesToChange,
|
||||||
|
removeImagesFromBoard,
|
||||||
|
selectedBoard,
|
||||||
|
]);
|
||||||
|
|
||||||
|
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AlertDialog
|
||||||
|
isOpen={isModalOpen}
|
||||||
|
onClose={handleClose}
|
||||||
|
leastDestructiveRef={cancelRef}
|
||||||
|
isCentered
|
||||||
|
>
|
||||||
|
<AlertDialogOverlay>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||||
|
Change Board
|
||||||
|
</AlertDialogHeader>
|
||||||
|
|
||||||
|
<AlertDialogBody>
|
||||||
|
<Flex sx={{ flexDir: 'column', gap: 4 }}>
|
||||||
|
<Text>
|
||||||
|
Moving {`${imagesToChange.length}`} image
|
||||||
|
{`${imagesToChange.length > 1 ? 's' : ''}`} to board:
|
||||||
|
</Text>
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
placeholder={isFetching ? 'Loading...' : 'Select Board'}
|
||||||
|
disabled={isFetching}
|
||||||
|
onChange={(v) => setSelectedBoard(v)}
|
||||||
|
value={selectedBoard}
|
||||||
|
data={data}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</AlertDialogBody>
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<IAIButton ref={cancelRef} onClick={handleClose}>
|
||||||
|
Cancel
|
||||||
|
</IAIButton>
|
||||||
|
<IAIButton colorScheme="accent" onClick={handleChangeBoard} ml={3}>
|
||||||
|
Move
|
||||||
|
</IAIButton>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialogOverlay>
|
||||||
|
</AlertDialog>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ChangeBoardModal);
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
import { ChangeBoardModalState } from './types';
|
||||||
|
|
||||||
|
export const initialState: ChangeBoardModalState = {
|
||||||
|
isModalOpen: false,
|
||||||
|
imagesToChange: [],
|
||||||
|
};
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
import { initialState } from './initialState';
|
||||||
|
|
||||||
|
const changeBoardModal = createSlice({
|
||||||
|
name: 'changeBoardModal',
|
||||||
|
initialState,
|
||||||
|
reducers: {
|
||||||
|
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.isModalOpen = action.payload;
|
||||||
|
},
|
||||||
|
imagesToChangeSelected: (state, action: PayloadAction<ImageDTO[]>) => {
|
||||||
|
state.imagesToChange = action.payload;
|
||||||
|
},
|
||||||
|
changeBoardReset: (state) => {
|
||||||
|
state.imagesToChange = [];
|
||||||
|
state.isModalOpen = false;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } =
|
||||||
|
changeBoardModal.actions;
|
||||||
|
|
||||||
|
export default changeBoardModal.reducer;
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
|
export type ChangeBoardModalState = {
|
||||||
|
isModalOpen: boolean;
|
||||||
|
imagesToChange: ImageDTO[];
|
||||||
|
};
|
||||||
@@ -3,6 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { FaCopy, FaTrash } from 'react-icons/fa';
|
import { FaCopy, FaTrash } from 'react-icons/fa';
|
||||||
import {
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
controlNetDuplicated,
|
controlNetDuplicated,
|
||||||
controlNetRemoved,
|
controlNetRemoved,
|
||||||
controlNetToggled,
|
controlNetToggled,
|
||||||
@@ -27,18 +28,27 @@ import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcesso
|
|||||||
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
|
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
|
||||||
|
|
||||||
type ControlNetProps = {
|
type ControlNetProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ControlNet = (props: ControlNetProps) => {
|
const ControlNet = (props: ControlNetProps) => {
|
||||||
const { controlNetId } = props;
|
const { controlNet } = props;
|
||||||
|
const { controlNetId } = controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
({ controlNet }) => {
|
({ controlNet }) => {
|
||||||
const { isEnabled, shouldAutoConfig } =
|
const cn = controlNet.controlNets[controlNetId];
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
|
if (!cn) {
|
||||||
|
return {
|
||||||
|
isEnabled: false,
|
||||||
|
shouldAutoConfig: false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const { isEnabled, shouldAutoConfig } = cn;
|
||||||
|
|
||||||
return { isEnabled, shouldAutoConfig };
|
return { isEnabled, shouldAutoConfig };
|
||||||
},
|
},
|
||||||
@@ -96,7 +106,7 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
transitionDuration: '0.1s',
|
transitionDuration: '0.1s',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ParamControlNetModel controlNetId={controlNetId} />
|
<ParamControlNetModel controlNet={controlNet} />
|
||||||
</Box>
|
</Box>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
size="sm"
|
size="sm"
|
||||||
@@ -171,8 +181,8 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
justifyContent: 'space-between',
|
justifyContent: 'space-between',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ParamControlNetWeight controlNetId={controlNetId} />
|
<ParamControlNetWeight controlNet={controlNet} />
|
||||||
<ParamControlNetBeginEnd controlNetId={controlNetId} />
|
<ParamControlNetBeginEnd controlNet={controlNet} />
|
||||||
</Flex>
|
</Flex>
|
||||||
{!isExpanded && (
|
{!isExpanded && (
|
||||||
<Flex
|
<Flex
|
||||||
@@ -184,22 +194,22 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
aspectRatio: '1/1',
|
aspectRatio: '1/1',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
|
<ControlNetImagePreview controlNet={controlNet} height={28} />
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex sx={{ gap: 2 }}>
|
<Flex sx={{ gap: 2 }}>
|
||||||
<ParamControlNetControlMode controlNetId={controlNetId} />
|
<ParamControlNetControlMode controlNet={controlNet} />
|
||||||
<ParamControlNetResizeMode controlNetId={controlNetId} />
|
<ParamControlNetResizeMode controlNet={controlNet} />
|
||||||
</Flex>
|
</Flex>
|
||||||
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
|
<ParamControlNetProcessorSelect controlNet={controlNet} />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
{isExpanded && (
|
{isExpanded && (
|
||||||
<>
|
<>
|
||||||
<ControlNetImagePreview controlNetId={controlNetId} height="392px" />
|
<ControlNetImagePreview controlNet={controlNet} height="392px" />
|
||||||
<ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
|
<ParamControlNetShouldAutoConfig controlNet={controlNet} />
|
||||||
<ControlNetProcessorComponent controlNetId={controlNetId} />
|
<ControlNetProcessorComponent controlNet={controlNet} />
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|||||||
@@ -12,50 +12,41 @@ import IAIDndImage from 'common/components/IAIDndImage';
|
|||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import { PostUploadAction } from 'services/api/types';
|
import { PostUploadAction } from 'services/api/types';
|
||||||
import { controlNetImageChanged } from '../store/controlNetSlice';
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetImageChanged,
|
||||||
|
} from '../store/controlNetSlice';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
height: SystemStyleObject['h'];
|
height: SystemStyleObject['h'];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ controlNet }) => {
|
||||||
|
const { pendingControlImages } = controlNet;
|
||||||
|
|
||||||
|
return {
|
||||||
|
pendingControlImages,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
const ControlNetImagePreview = (props: Props) => {
|
const ControlNetImagePreview = (props: Props) => {
|
||||||
const { height, controlNetId } = props;
|
const { height } = props;
|
||||||
|
const {
|
||||||
|
controlImage: controlImageName,
|
||||||
|
processedControlImage: processedControlImageName,
|
||||||
|
processorType,
|
||||||
|
isEnabled,
|
||||||
|
controlNetId,
|
||||||
|
} = props.controlNet;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const selector = useMemo(
|
const { pendingControlImages } = useAppSelector(selector);
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { pendingControlImages } = controlNet;
|
|
||||||
const {
|
|
||||||
controlImage,
|
|
||||||
processedControlImage,
|
|
||||||
processorType,
|
|
||||||
isEnabled,
|
|
||||||
} = controlNet.controlNets[controlNetId];
|
|
||||||
|
|
||||||
return {
|
|
||||||
controlImageName: controlImage,
|
|
||||||
processedControlImageName: processedControlImage,
|
|
||||||
processorType,
|
|
||||||
isEnabled,
|
|
||||||
pendingControlImages,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const {
|
|
||||||
controlImageName,
|
|
||||||
processedControlImageName,
|
|
||||||
processorType,
|
|
||||||
pendingControlImages,
|
|
||||||
isEnabled,
|
|
||||||
} = useAppSelector(selector);
|
|
||||||
|
|
||||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { memo } from 'react';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { ControlNetConfig } from '../store/controlNetSlice';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import { memo, useMemo } from 'react';
|
|
||||||
import CannyProcessor from './processors/CannyProcessor';
|
import CannyProcessor from './processors/CannyProcessor';
|
||||||
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
|
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
|
||||||
import HedProcessor from './processors/HedProcessor';
|
import HedProcessor from './processors/HedProcessor';
|
||||||
@@ -17,28 +14,11 @@ import PidiProcessor from './processors/PidiProcessor';
|
|||||||
import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
|
import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
|
||||||
|
|
||||||
export type ControlNetProcessorProps = {
|
export type ControlNetProcessorProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||||
const { controlNetId } = props;
|
const { controlNetId, isEnabled, processorNode } = props.controlNet;
|
||||||
|
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { isEnabled, processorNode } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
|
|
||||||
return { isEnabled, processorNode };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { isEnabled, processorNode } = useAppSelector(selector);
|
|
||||||
|
|
||||||
if (processorNode.type === 'canny_image_processor') {
|
if (processorNode.type === 'canny_image_processor') {
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -1,34 +1,19 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetAutoConfigToggled,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
||||||
const { controlNetId } = props;
|
const { controlNetId, isEnabled, shouldAutoConfig } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { isEnabled, shouldAutoConfig } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
return { isEnabled, shouldAutoConfig };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
|
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleShouldAutoConfigChanged = useCallback(() => {
|
const handleShouldAutoConfigChanged = useCallback(() => {
|
||||||
|
|||||||
@@ -9,48 +9,39 @@ import {
|
|||||||
RangeSliderTrack,
|
RangeSliderTrack,
|
||||||
Tooltip,
|
Tooltip,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import {
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
controlNetBeginStepPctChanged,
|
controlNetBeginStepPctChanged,
|
||||||
controlNetEndStepPctChanged,
|
controlNetEndStepPctChanged,
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
||||||
|
|
||||||
const ParamControlNetBeginEnd = (props: Props) => {
|
const ParamControlNetBeginEnd = (props: Props) => {
|
||||||
const { controlNetId } = props;
|
const { beginStepPct, endStepPct, isEnabled, controlNetId } =
|
||||||
|
props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { beginStepPct, endStepPct, isEnabled } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
return { beginStepPct, endStepPct, isEnabled };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
|
|
||||||
|
|
||||||
const handleStepPctChanged = useCallback(
|
const handleStepPctChanged = useCallback(
|
||||||
(v: number[]) => {
|
(v: number[]) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
controlNetBeginStepPctChanged({ controlNetId, beginStepPct: v[0] })
|
controlNetBeginStepPctChanged({
|
||||||
|
controlNetId,
|
||||||
|
beginStepPct: v[0] as number,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
controlNetEndStepPctChanged({
|
||||||
|
controlNetId,
|
||||||
|
endStepPct: v[1] as number,
|
||||||
|
})
|
||||||
);
|
);
|
||||||
dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct: v[1] }));
|
|
||||||
},
|
},
|
||||||
[controlNetId, dispatch]
|
[controlNetId, dispatch]
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,16 +1,14 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import {
|
import {
|
||||||
ControlModes,
|
ControlModes,
|
||||||
|
ControlNetConfig,
|
||||||
controlNetControlModeChanged,
|
controlNetControlModeChanged,
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
type ParamControlNetControlModeProps = {
|
type ParamControlNetControlModeProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const CONTROL_MODE_DATA = [
|
const CONTROL_MODE_DATA = [
|
||||||
@@ -23,23 +21,8 @@ const CONTROL_MODE_DATA = [
|
|||||||
export default function ParamControlNetControlMode(
|
export default function ParamControlNetControlMode(
|
||||||
props: ParamControlNetControlModeProps
|
props: ParamControlNetControlModeProps
|
||||||
) {
|
) {
|
||||||
const { controlNetId } = props;
|
const { controlMode, isEnabled, controlNetId } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { controlMode, isEnabled } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
return { controlMode, isEnabled };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { controlMode, isEnabled } = useAppSelector(selector);
|
|
||||||
|
|
||||||
const handleControlModeChange = useCallback(
|
const handleControlModeChange = useCallback(
|
||||||
(controlMode: ControlModes) => {
|
(controlMode: ControlModes) => {
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetModelChanged,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
@@ -14,30 +17,24 @@ import { memo, useCallback, useMemo } from 'react';
|
|||||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
type ParamControlNetModelProps = {
|
type ParamControlNetModelProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ generation }) => {
|
||||||
|
const { model } = generation;
|
||||||
|
return { mainModel: model };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||||
const { controlNetId } = props;
|
const { controlNetId, model: controlNetModel, isEnabled } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const selector = useMemo(
|
const { mainModel } = useAppSelector(selector);
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ generation, controlNet }) => {
|
|
||||||
const { model } = generation;
|
|
||||||
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
|
|
||||||
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
|
|
||||||
return { mainModel: model, controlNetModel, isEnabled };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
|
|
||||||
|
|
||||||
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIMantineSearchableSelect, {
|
import IAIMantineSearchableSelect, {
|
||||||
IAISelectDataType,
|
IAISelectDataType,
|
||||||
@@ -9,13 +8,16 @@ import IAIMantineSearchableSelect, {
|
|||||||
import { configSelector } from 'features/system/store/configSelectors';
|
import { configSelector } from 'features/system/store/configSelectors';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
||||||
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetProcessorTypeChanged,
|
||||||
|
} from '../../store/controlNetSlice';
|
||||||
import { ControlNetProcessorType } from '../../store/types';
|
import { ControlNetProcessorType } from '../../store/types';
|
||||||
|
|
||||||
type ParamControlNetProcessorSelectProps = {
|
type ParamControlNetProcessorSelectProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@@ -52,23 +54,9 @@ const ParamControlNetProcessorSelect = (
|
|||||||
props: ParamControlNetProcessorSelectProps
|
props: ParamControlNetProcessorSelectProps
|
||||||
) => {
|
) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { controlNetId } = props;
|
const { controlNetId, isEnabled, processorNode } = props.controlNet;
|
||||||
const processorNodeSelector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { isEnabled, processorNode } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
return { isEnabled, processorNode };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
const controlNetProcessors = useAppSelector(selector);
|
const controlNetProcessors = useAppSelector(selector);
|
||||||
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
|
|
||||||
|
|
||||||
const handleProcessorTypeChanged = useCallback(
|
const handleProcessorTypeChanged = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
|
|||||||
@@ -1,16 +1,14 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import {
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
ResizeModes,
|
ResizeModes,
|
||||||
controlNetResizeModeChanged,
|
controlNetResizeModeChanged,
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
type ParamControlNetResizeModeProps = {
|
type ParamControlNetResizeModeProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const RESIZE_MODE_DATA = [
|
const RESIZE_MODE_DATA = [
|
||||||
@@ -22,23 +20,8 @@ const RESIZE_MODE_DATA = [
|
|||||||
export default function ParamControlNetResizeMode(
|
export default function ParamControlNetResizeMode(
|
||||||
props: ParamControlNetResizeModeProps
|
props: ParamControlNetResizeModeProps
|
||||||
) {
|
) {
|
||||||
const { controlNetId } = props;
|
const { resizeMode, isEnabled, controlNetId } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { resizeMode, isEnabled } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
return { resizeMode, isEnabled };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { resizeMode, isEnabled } = useAppSelector(selector);
|
|
||||||
|
|
||||||
const handleResizeModeChange = useCallback(
|
const handleResizeModeChange = useCallback(
|
||||||
(resizeMode: ResizeModes) => {
|
(resizeMode: ResizeModes) => {
|
||||||
|
|||||||
@@ -1,32 +1,18 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
|
import {
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
ControlNetConfig,
|
||||||
|
controlNetWeightChanged,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
type ParamControlNetWeightProps = {
|
type ParamControlNetWeightProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
||||||
const { controlNetId } = props;
|
const { weight, isEnabled, controlNetId } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
|
|
||||||
return { weight, isEnabled };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { weight, isEnabled } = useAppSelector(selector);
|
|
||||||
const handleWeightChanged = useCallback(
|
const handleWeightChanged = useCallback(
|
||||||
(weight: number) => {
|
(weight: number) => {
|
||||||
dispatch(controlNetWeightChanged({ controlNetId, weight }));
|
dispatch(controlNetWeightChanged({ controlNetId, weight }));
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user