Compare commits

..

43 Commits

Author SHA1 Message Date
psychedelicious
54bda8e8e4 chore: bump version to v6.2.0a1 2025-07-25 17:16:51 +10:00
psychedelicious
c14889f055 tidy(ui): enable devmode redux checks 2025-07-25 17:15:19 +10:00
psychedelicious
86680f296a chore(ui): lint 2025-07-25 17:15:19 +10:00
psychedelicious
de0b7801a6 fix(ui): infinite loop when setting tile controlnet model 2025-07-25 17:15:19 +10:00
psychedelicious
a03c7ca4e3 fix(ui): do not store whole model configs in state 2025-07-25 17:15:19 +10:00
psychedelicious
32af53779d refactor(ui): just manually validate async stuff 2025-07-25 17:15:19 +10:00
psychedelicious
a8662953fc refactor(ui): work around zod async validation issue 2025-07-25 17:15:19 +10:00
psychedelicious
82cdfd83e4 fix(ui): check initial retrieval and set as last persisted 2025-07-25 17:15:19 +10:00
psychedelicious
3f3fdf0b43 chore(ui): bump zod to latest
Checking if it fixes an issue w/ async validators
2025-07-25 17:15:18 +10:00
psychedelicious
53dbd5a7c9 refactor(ui): use zod for all redux state 2025-07-25 17:15:18 +10:00
psychedelicious
bbe5979349 refactor(ui): use zod for all redux state (wip)
needed for confidence w/ state rehydration logic
2025-07-25 17:15:18 +10:00
psychedelicious
ca70540ddd feat(ui): iterate on storage api 2025-07-25 17:15:18 +10:00
psychedelicious
37e25ccbf7 refactor(ui): restructure persistence driver creation to support custom drivers 2025-07-25 17:15:18 +10:00
psychedelicious
28e7a83f98 revert(ui): temp changes to main.tsx for testing 2025-07-25 17:15:18 +10:00
psychedelicious
3b39912b1c revert(ui): temp disable eslint rule 2025-07-25 17:15:18 +10:00
psychedelicious
c76698f205 git: update gitignore 2025-07-25 17:15:18 +10:00
psychedelicious
8f27a393d8 wip 2025-07-25 17:15:18 +10:00
psychedelicious
84ff6dbe69 chore: ruff 2025-07-25 17:15:18 +10:00
psychedelicious
4620a2137c tests(app): service mocks 2025-07-25 17:15:18 +10:00
psychedelicious
8ddbd979dd chore(ui): lint 2025-07-25 17:15:17 +10:00
psychedelicious
19ec9d268e refactor(ui): iterate on persistence 2025-07-25 17:15:17 +10:00
psychedelicious
ab683802ba refactor(ui): iterate on persistence 2025-07-25 17:15:17 +10:00
psychedelicious
98957ec9ea refactor(ui): alternate approach to slice configs 2025-07-25 17:15:17 +10:00
psychedelicious
7936ee9b7f chore(ui): typegen 2025-07-25 17:15:17 +10:00
psychedelicious
a96b7afdfb feat(api): make client state key query not body 2025-07-25 17:15:17 +10:00
psychedelicious
bb58a70b70 refactor(ui): cleaner slice definitions 2025-07-25 17:15:17 +10:00
psychedelicious
aaa1e1a480 feat: server-side client state persistence 2025-07-25 17:15:17 +10:00
Riccardo Giovanetti
7bea2fa11f translationBot(ui): update translation (Italian)
Currently translated at 98.6% (2016 of 2044 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (2015 of 2043 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-07-25 17:15:01 +10:00
psychedelicious
169d58ea4c feat(ui): restore clear queue button
It is accessible in two places:
- The queue actions hamburger menu.
- On the queue tab.

If the clear queue app feature is disabled, it is not shown in either of
those places.
2025-07-23 23:38:53 +10:00
psychedelicious
b53d2250f7 feat(ui): reduce snap tolerance to make it easier to break the snap 2025-07-23 23:05:40 +10:00
psychedelicious
242eea8295 fix(ui): incorrect zoom direction w/ small scroll amounts 2025-07-23 23:05:40 +10:00
psychedelicious
4dabe09e0d tests(ui): remove test for no-longer-valid behaviour 2025-07-23 23:03:02 +10:00
psychedelicious
07fa0d3b77 fix(ui): do not attempt toggle when target panel isn't registered 2025-07-23 23:03:02 +10:00
psychedelicious
e97f82292f tests(ui): add tests for disposable handling 2025-07-23 23:03:02 +10:00
psychedelicious
005bab9035 fix(ui): tab disposables not being added correctly 2025-07-23 23:03:02 +10:00
psychedelicious
409173919c tests(ui): add tests for toggleViewer functionality 2025-07-23 23:03:02 +10:00
psychedelicious
7915180047 feat(ui): restore viewer toggle hotkey 2025-07-23 23:03:02 +10:00
Riccardo Giovanetti
4349b8387d translationBot(ui): update translation (Italian)
Currently translated at 97.9% (2000 of 2042 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-07-23 12:26:48 +10:00
Kent Keirsey
f95b686bdc reposition export button 2025-07-23 11:55:11 +10:00
Mary Hipp
72afb9c3fd fix iterations for all API models 2025-07-22 13:27:35 -04:00
Mary Hipp
f004fc31f1 update whats new 2025-07-22 12:24:10 -04:00
psychedelicious
2aa163b3a2 feat(ui): add default inpaint mask layer on canvas reset 2025-07-22 10:26:57 +10:00
psychedelicious
f40900c173 chore: bump version to v6.1.0 2025-07-22 08:24:31 +10:00
104 changed files with 2963 additions and 1344 deletions

View File

@@ -10,6 +10,7 @@ from invokeai.app.services.board_images.board_images_default import BoardImagesS
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.boards.boards_default import BoardService
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.download.download_default import DownloadQueueService
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
@@ -151,6 +152,7 @@ class ApiDependencies:
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
client_state_persistence = ClientStatePersistenceSqlite(db=db)
services = InvocationServices(
board_image_records=board_image_records,
@@ -181,6 +183,7 @@ class ApiDependencies:
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
workflow_thumbnails=workflow_thumbnails,
client_state_persistence=client_state_persistence,
)
ApiDependencies.invoker = Invoker(services)

View File

@@ -5,9 +5,9 @@ from pathlib import Path
from typing import Optional
import torch
from fastapi import Body
from fastapi import Body, HTTPException, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, JsonValue
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.upscale import ESRGAN_MODELS
@@ -173,3 +173,50 @@ async def disable_invocation_cache() -> None:
async def get_invocation_cache_status() -> InvocationCacheStatus:
"""Clears the invocation cache"""
return ApiDependencies.invoker.services.invocation_cache.get_status()
@app_router.get(
"/client_state",
operation_id="get_client_state_by_key",
response_model=JsonValue | None,
)
async def get_client_state_by_key(
key: str = Query(..., description="Key to get"),
) -> JsonValue | None:
"""Gets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key)
except Exception as e:
logging.error(f"Error getting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@app_router.post(
"/client_state",
operation_id="set_client_state",
response_model=None,
)
async def set_client_state(
key: str = Query(..., description="Key to set"),
value: JsonValue = Body(..., description="Value of the key"),
) -> None:
"""Sets the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@app_router.delete(
"/client_state",
operation_id="delete_client_state",
responses={204: {"description": "Client state deleted"}},
)
async def delete_client_state() -> None:
"""Deletes the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.delete()
except Exception as e:
logging.error(f"Error deleting client state: {e}")
raise HTTPException(status_code=500, detail="Error deleting client state")

View File

@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from pydantic import JsonValue
class ClientStatePersistenceABC(ABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
@abstractmethod
def set_by_key(self, key: str, value: JsonValue) -> None:
"""
Store the data for the client.
:param data: The client data to be stored.
"""
pass
@abstractmethod
def get_by_key(self, key: str) -> JsonValue | None:
"""
Get the data for the client.
:return: The client data.
"""
pass
@abstractmethod
def delete(self) -> None:
"""
Delete the data for the client.
"""
pass

View File

@@ -0,0 +1,65 @@
import json
from pydantic import JsonValue
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._default_row_id = 1
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def set_by_key(self, key: str, value: JsonValue) -> None:
state = self.get() or {}
state.update({key: value})
with self._db.transaction() as cursor:
cursor.execute(
f"""
INSERT INTO client_state (id, data)
VALUES ({self._default_row_id}, ?)
ON CONFLICT(id) DO UPDATE
SET data = excluded.data;
""",
(json.dumps(state),),
)
def get(self) -> dict[str, JsonValue] | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])
def get_by_key(self, key: str) -> JsonValue | None:
state = self.get()
if state is None:
return None
return state.get(key, None)
def delete(self) -> None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
DELETE FROM client_state
WHERE id = {self._default_row_id}
"""
)

View File

@@ -17,6 +17,7 @@ if TYPE_CHECKING:
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
@@ -73,6 +74,7 @@ class InvocationServices:
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
workflow_thumbnails: "WorkflowThumbnailServiceBase",
client_state_persistence: "ClientStatePersistenceABC",
):
self.board_images = board_images
self.board_image_records = board_image_records
@@ -102,3 +104,4 @@ class InvocationServices:
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files
self.workflow_thumbnails = workflow_thumbnails
self.client_state_persistence = client_state_persistence

View File

@@ -23,6 +23,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -63,6 +64,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_18())
migrator.register_migration(build_migration_19(app_config=config))
migrator.register_migration(build_migration_20())
migrator.register_migration(build_migration_21())
migrator.run_migrations()
return db

View File

@@ -0,0 +1,40 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration21Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""
CREATE TABLE client_state (
id INTEGER PRIMARY KEY CHECK(id = 1),
data TEXT NOT NULL, -- Frontend will handle the shape of this data
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP)
);
"""
)
cursor.execute(
"""
CREATE TRIGGER tg_client_state_updated_at
AFTER UPDATE ON client_state
FOR EACH ROW
BEGIN
UPDATE client_state
SET updated_at = CURRENT_TIMESTAMP
WHERE id = OLD.id;
END;
"""
)
def build_migration_21() -> Migration:
"""Builds the migration object for migrating from version 20 to version 21. This includes:
- Creating the `client_state` table.
- Adding a trigger to update the `updated_at` field on updates.
"""
return Migration(
from_version=20,
to_version=21,
callback=Migration21Callback(),
)

View File

@@ -44,4 +44,5 @@ yalc.lock
# vitest
tsconfig.vitest-temp.json
coverage/
coverage/
*.tgz

View File

@@ -26,7 +26,7 @@ i18n.use(initReactI18next).init({
returnNull: false,
});
const store = createStore(undefined, false);
const store = createStore({ driver: { getItem: () => {}, setItem: () => {} }, persistThrottle: 2000 });
$store.set(store);
$baseUrl.set('http://localhost:9090');

View File

@@ -197,6 +197,10 @@ export default [
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
{
name: 'zod/v3',
message: 'Import from zod instead.',
},
],
},
],

View File

@@ -63,7 +63,6 @@
"framer-motion": "^11.10.0",
"i18next": "^25.3.2",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "6.2.2",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.22",
"linkify-react": "^4.3.1",
@@ -103,7 +102,7 @@
"use-debounce": "^10.0.5",
"use-device-pixel-ratio": "^1.1.2",
"uuid": "^11.1.0",
"zod": "^4.0.5",
"zod": "^4.0.10",
"zod-validation-error": "^3.5.2"
},
"peerDependencies": {

View File

@@ -80,9 +80,6 @@ importers:
i18next-http-backend:
specifier: ^3.0.2
version: 3.0.2
idb-keyval:
specifier: 6.2.2
version: 6.2.2
jsondiffpatch:
specifier: ^0.7.3
version: 0.7.3
@@ -201,11 +198,11 @@ importers:
specifier: ^11.1.0
version: 11.1.0
zod:
specifier: ^4.0.5
version: 4.0.5
specifier: ^4.0.10
version: 4.0.10
zod-validation-error:
specifier: ^3.5.2
version: 3.5.3(zod@4.0.5)
version: 3.5.3(zod@4.0.10)
devDependencies:
'@eslint/js':
specifier: ^9.31.0
@@ -411,6 +408,10 @@ packages:
resolution: {integrity: sha512-vbavdySgbTTrmFE+EsiqUTzlOr5bzlnJtUv9PynGCAKvfQqjIXbvFdumPM/GxMDfyuGMJaJAU6TO4zc1Jf1i8Q==}
engines: {node: '>=6.9.0'}
'@babel/runtime@7.28.2':
resolution: {integrity: sha512-KHp2IflsnGywDjBWDkR9iEqiWSpc8GIi0lgTT3mOElT0PP1tG26P4tmFI2YvAdzgq9RGyoHZQEIEdZy6Ec5xCA==}
engines: {node: '>=6.9.0'}
'@babel/template@7.27.2':
resolution: {integrity: sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==}
engines: {node: '>=6.9.0'}
@@ -2771,9 +2772,6 @@ packages:
typescript:
optional: true
idb-keyval@6.2.2:
resolution: {integrity: sha512-yjD9nARJ/jb1g+CvD0tlhUHOrJ9Sy0P8T9MF3YaLlHnSRpwPfpTX0XIvpmw3gAJUmEu3FiICLBDPXVwyEvrleg==}
ieee754@1.2.1:
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
@@ -4511,8 +4509,8 @@ packages:
zod@3.25.76:
resolution: {integrity: sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==}
zod@4.0.5:
resolution: {integrity: sha512-/5UuuRPStvHXu7RS+gmvRf4NXrNxpSllGwDnCBcJZtQsKrviYXm54yDGV2KYNLT5kq0lHGcl7lqWJLgSaG+tgA==}
zod@4.0.10:
resolution: {integrity: sha512-3vB+UU3/VmLL2lvwcY/4RV2i9z/YU0DTV/tDuYjrwmx5WeJ7hwy+rGEEx8glHp6Yxw7ibRbKSaIFBgReRPe5KA==}
zustand@4.5.7:
resolution: {integrity: sha512-CHOUy7mu3lbD6o6LJLfllpjkzhHXSBlX8B9+qPddUsIfeF5S/UZ5q0kmCsnRqT1UHFQZchNFDDzMbQsuesHWlw==}
@@ -4633,6 +4631,8 @@ snapshots:
'@babel/runtime@7.27.6': {}
'@babel/runtime@7.28.2': {}
'@babel/template@7.27.2':
dependencies:
'@babel/code-frame': 7.27.1
@@ -5736,7 +5736,7 @@ snapshots:
'@testing-library/dom@10.4.0':
dependencies:
'@babel/code-frame': 7.27.1
'@babel/runtime': 7.27.6
'@babel/runtime': 7.28.2
'@types/aria-query': 5.0.4
aria-query: 5.3.0
chalk: 4.1.2
@@ -7266,8 +7266,6 @@ snapshots:
optionalDependencies:
typescript: 5.8.3
idb-keyval@6.2.2: {}
ieee754@1.2.1: {}
ignore@5.3.2: {}
@@ -9062,13 +9060,13 @@ snapshots:
dependencies:
zod: 3.25.76
zod-validation-error@3.5.3(zod@4.0.5):
zod-validation-error@3.5.3(zod@4.0.10):
dependencies:
zod: 4.0.5
zod: 4.0.10
zod@3.25.76: {}
zod@4.0.5: {}
zod@4.0.10: {}
zustand@4.5.7(@types/react@18.3.23)(immer@10.1.1)(react@18.3.1):
dependencies:

View File

@@ -253,6 +253,7 @@
"cancel": "Cancel",
"cancelAllExceptCurrentQueueItemAlertDialog": "Canceling all queue items except the current one will stop pending items but allow the in-progress one to finish.",
"cancelAllExceptCurrentQueueItemAlertDialog2": "Are you sure you want to cancel all pending queue items?",
"cancelAllExceptCurrent": "Cancel All Except Current",
"cancelAllExceptCurrentTooltip": "Cancel All Except Current Item",
"cancelTooltip": "Cancel Current Item",
"cancelSucceeded": "Item Canceled",
@@ -273,7 +274,7 @@
"retryItem": "Retry Item",
"cancelBatchSucceeded": "Batch Canceled",
"cancelBatchFailed": "Problem Canceling Batch",
"clearQueueAlertDialog": "Clearing the queue immediately cancels any processing items and clears the queue entirely. Pending filters will be canceled.",
"clearQueueAlertDialog": "Clearing the queue immediately cancels any processing items and clears the queue entirely. Pending filters will be canceled and the Canvas Staging Area will be reset.",
"clearQueueAlertDialog2": "Are you sure you want to clear the queue?",
"current": "Current",
"next": "Next",
@@ -2630,9 +2631,10 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"Generate images faster with new Launchpads and a simplified Generate tab.",
"Edit with prompts using Flux Kontext Dev.",
"Export to PSD, bulk-hide overlays, organize models & images — all in a reimagined interface built for control."
"New setting to send all Canvas generations directly to the Gallery.",
"New Invert Mask (Shift+V) and Fit BBox to Mask (Shift+B) capabilities.",
"Expanded support for Model Thumbnails and configurations.",
"Various other quality of life updates and fixes"
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",

View File

@@ -254,12 +254,16 @@
"desc": "Attiva/disattiva il pannello destro."
},
"resetPanelLayout": {
"title": "Ripristina il layout del pannello",
"desc": "Ripristina le dimensioni e il layout predefiniti dei pannelli sinistro e destro."
"title": "Ripristina lo schema del pannello",
"desc": "Ripristina le dimensioni e lo schema predefiniti dei pannelli sinistro e destro."
},
"togglePanels": {
"title": "Attiva/disattiva i pannelli",
"desc": "Mostra o nascondi contemporaneamente i pannelli sinistro e destro."
},
"selectGenerateTab": {
"title": "Seleziona la scheda Genera",
"desc": "Seleziona la scheda Genera."
}
},
"hotkeys": "Tasti di scelta rapida",
@@ -389,6 +393,23 @@
"behavior": "Comportamento",
"display": "Mostra",
"grid": "Griglia"
},
"invertMask": {
"title": "Inverti maschera",
"desc": "Inverte la maschera di inpaint selezionata, creando una nuova maschera con trasparenza opposta."
},
"fitBboxToMasks": {
"title": "Adatta il riquadro di delimitazione alle maschere",
"desc": "Regola automaticamente il riquadro di delimitazione della generazione per adattarlo alle maschere di inpaint visibili"
},
"applySegmentAnything": {
"title": "Applica Segment Anything",
"desc": "Applica la maschera Segment Anything corrente.",
"key": "invio"
},
"cancelSegmentAnything": {
"title": "Annulla Segment Anything",
"desc": "Annulla l'operazione Segment Anything corrente."
}
},
"workflows": {
@@ -518,6 +539,10 @@
"galleryNavUpAlt": {
"desc": "Uguale a Naviga verso l'alto, ma seleziona l'immagine da confrontare, aprendo la modalità di confronto se non è già aperta.",
"title": "Naviga verso l'alto (Confronta immagine)"
},
"starImage": {
"desc": "Aggiungi/Rimuovi contrassegno all'immagine selezionata.",
"title": "Aggiungi / Rimuovi contrassegno immagine"
}
}
},
@@ -936,7 +961,15 @@
"canvasManagerNotAvailable": "Gestione tela non disponibile",
"promptExpansionFailed": "Abbiamo riscontrato un problema. Riprova a eseguire l'espansione del prompt.",
"uploadAndPromptGenerationFailed": "Impossibile caricare l'immagine e generare il prompt",
"promptGenerationStarted": "Generazione del prompt avviata"
"promptGenerationStarted": "Generazione del prompt avviata",
"invalidBboxDesc": "Il riquadro di delimitazione non ha dimensioni valide",
"invalidBbox": "Riquadro di delimitazione non valido",
"noInpaintMaskSelectedDesc": "Seleziona una maschera di inpaint da invertire",
"noInpaintMaskSelected": "Nessuna maschera di inpaint selezionata",
"noVisibleMasksDesc": "Crea o abilita almeno una maschera inpaint da invertire",
"noVisibleMasks": "Nessuna maschera visibile",
"maskInvertFailed": "Impossibile invertire la maschera",
"maskInverted": "Maschera invertita"
},
"accessibility": {
"invokeProgressBar": "Barra di avanzamento generazione",
@@ -1131,7 +1164,22 @@
"missingField_withName": "Campo \"{{name}}\" mancante",
"unknownFieldEditWorkflowToFix_withName": "Il flusso di lavoro contiene un campo \"{{name}}\" sconosciuto .\nModifica il flusso di lavoro per risolvere il problema.",
"unexpectedField_withName": "Campo \"{{name}}\" inaspettato",
"missingSourceOrTargetHandle": "Identificatore del nodo sorgente o di destinazione mancante"
"missingSourceOrTargetHandle": "Identificatore del nodo sorgente o di destinazione mancante",
"layout": {
"alignmentDR": "In basso a destra",
"autoLayout": "Schema automatico",
"nodeSpacing": "Spaziatura nodi",
"layerSpacing": "Spaziatura livelli",
"layeringStrategy": "Strategia livelli",
"longestPath": "Percorso più lungo",
"layoutDirection": "Direzione schema",
"layoutDirectionRight": "Orizzontale",
"layoutDirectionDown": "Verticale",
"alignment": "Allineamento nodi",
"alignmentUL": "In alto a sinistra",
"alignmentDL": "In basso a sinistra",
"alignmentUR": "In alto a destra"
}
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@@ -1208,7 +1256,7 @@
"batchQueuedDesc_other": "Aggiunte {{count}} sessioni a {{direction}} della coda",
"graphQueued": "Grafico in coda",
"batch": "Lotto",
"clearQueueAlertDialog": "Lo svuotamento della coda annulla immediatamente tutti gli elementi in elaborazione e cancella completamente la coda. I filtri in sospeso verranno annullati.",
"clearQueueAlertDialog": "La cancellazione della coda annulla immediatamente tutti gli elementi in elaborazione e cancella completamente la coda. I filtri in sospeso verranno annullati e l'area di lavoro della Tela verrà reimpostata.",
"pending": "In attesa",
"completedIn": "Completato in",
"resumeFailed": "Problema nel riavvio dell'elaborazione",
@@ -1264,7 +1312,8 @@
"retrySucceeded": "Elemento rieseguito",
"retryItem": "Riesegui elemento",
"retryFailed": "Problema riesecuzione elemento",
"credits": "Crediti"
"credits": "Crediti",
"cancelAllExceptCurrent": "Annulla tutto tranne quello corrente"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@@ -1679,7 +1728,7 @@
"structure": {
"heading": "Struttura",
"paragraphs": [
"La struttura determina quanto l'immagine finale rispecchierà il layout dell'originale. Una struttura bassa permette cambiamenti significativi, mentre una struttura alta conserva la composizione e il layout originali."
"La struttura determina quanto l'immagine finale rispecchierà il layout dell'originale. Un valore struttura basso permette cambiamenti significativi, mentre un valore struttura alto conserva la composizione e lo schema originali."
]
},
"fluxDevLicense": {
@@ -1845,7 +1894,7 @@
"opened": "Aperto",
"convertGraph": "Converti grafico",
"loadWorkflow": "$t(common.load) Flusso di lavoro",
"autoLayout": "Disposizione automatica",
"autoLayout": "Schema automatico",
"loadFromGraph": "Carica il flusso di lavoro dal grafico",
"userWorkflows": "Flussi di lavoro utente",
"projectWorkflows": "Flussi di lavoro del progetto",
@@ -2444,7 +2493,9 @@
"switchOnStart": "All'inizio",
"switchOnFinish": "Alla fine",
"off": "Spento"
}
},
"invertMask": "Inverti maschera",
"fitBboxToMasks": "Adatta il riquadro di delimitazione alle maschere"
},
"ui": {
"tabs": {
@@ -2597,9 +2648,10 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"Genera immagini più velocemente con le nuove Rampe di lancio e una scheda Genera semplificata.",
"Modifica con prompt utilizzando Flux Kontext Dev.",
"Esporta in PSD, nascondi sovrapposizioni in blocco, organizza modelli e immagini: il tutto in un'interfaccia riprogettata e pensata per il controllo."
"Nuova impostazione per inviare tutte le generazioni della Tela direttamente alla Galleria.",
"Nuove funzionalità Inverti maschera (Maiusc+V) e Adatta il Riquadro di delimitazione alla maschera (Maiusc+B).",
"Supporto esteso per miniature e configurazioni dei modelli.",
"Vari altri aggiornamenti e correzioni per la qualità della vita"
]
},
"system": {

View File

@@ -2,10 +2,10 @@ import { Box } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import { useClearStorage } from 'app/contexts/clear-storage-context';
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { AppContent } from 'features/ui/components/AppContent';
import { memo, useCallback } from 'react';
import { ErrorBoundary } from 'react-error-boundary';

View File

@@ -1,10 +1,12 @@
import 'i18n';
import type { Middleware } from '@reduxjs/toolkit';
import { ClearStorageProvider } from 'app/contexts/clear-storage-context';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { buildStorageApi } from 'app/store/enhancers/reduxRemember/driver';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
@@ -70,6 +72,14 @@ interface Props extends PropsWithChildren {
* If provided, overrides in-app navigation to the model manager
*/
onClickGoToModelManager?: () => void;
storageConfig?: {
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
getItem: (key: string) => Promise<any>;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
persistThrottle: number;
};
}
const InvokeAIUI = ({
@@ -96,6 +106,7 @@ const InvokeAIUI = ({
loggingOverrides,
onClickGoToModelManager,
whatsNew,
storageConfig,
}: Props) => {
useLayoutEffect(() => {
/*
@@ -308,9 +319,21 @@ const InvokeAIUI = ({
};
}, [isDebugging]);
const storage = useMemo(() => buildStorageApi(storageConfig), [storageConfig]);
useEffect(() => {
const storageCleanup = storage.registerListeners();
return () => {
storageCleanup();
};
}, [storage]);
const store = useMemo(() => {
return createStore(projectId);
}, [projectId]);
return createStore({
driver: storage.reduxRememberDriver,
persistThrottle: storageConfig?.persistThrottle ?? 2000,
});
}, [storage.reduxRememberDriver, storageConfig?.persistThrottle]);
useEffect(() => {
$store.set(store);
@@ -327,11 +350,13 @@ const InvokeAIUI = ({
return (
<React.StrictMode>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
<ClearStorageProvider value={storage.clearStorage}>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
</ClearStorageProvider>
</React.StrictMode>
);
};

View File

@@ -0,0 +1,10 @@
import { createContext, useContext } from 'react';
const ClearStorageContext = createContext<() => void>(() => {});
export const ClearStorageProvider = ClearStorageContext.Provider;
export const useClearStorage = () => {
const context = useContext(ClearStorageContext);
return context;
};

View File

@@ -1,3 +1,2 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = [];
export const EMPTY_OBJECT = {};

View File

@@ -1,40 +1,243 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { logger } from 'app/logging/logger';
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
import { $projectId } from 'app/store/nanostores/projectId';
import type { UseStore } from 'idb-keyval';
import { clear, createStore as createIDBKeyValStore, get, set } from 'idb-keyval';
import { atom } from 'nanostores';
import type { Driver } from 'redux-remember';
import type { Driver as ReduxRememberDriver } from 'redux-remember';
import { getBaseUrl } from 'services/api';
import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
// Create a custom idb-keyval store (just needed to customize the name)
const $idbKeyValStore = atom<UseStore>(createIDBKeyValStore('invoke', 'invoke-store'));
const log = logger('system');
export const clearIdbKeyValStore = () => {
clear($idbKeyValStore.get());
const buildOSSServerBackedDriver = (): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
let persistRefCount = 0;
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
//
// `redux-remember` persists individual slices of state, so we can implicity denylist a slice by not giving it a
// persist config.
//
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
// way to do this directly.
//
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
// the implementation in `store.ts` for this logic.
//
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
// whole slice, even if the final, _serialized_ slice value is unchanged.
//
// To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to
// be persisted is the same as the last persisted value, we can skip the network request.
const lastPersistedState = new Map<string, unknown>();
const getUrl = (key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
if (key) {
query['key'] = key;
}
const path = buildAppInfoUrl('client_state', query);
const url = `${baseUrl}/${path}`;
return url;
};
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
const url = getUrl(key);
const res = await fetch(url, { method: 'GET' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
const text = await res.text();
if (!lastPersistedState.get(key)) {
lastPersistedState.set(key, text);
}
return JSON.parse(text);
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key, value) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping persist for key "${key}" as value is unchanged.`);
return value;
}
const url = getUrl(key);
const headers = new Headers({
'Content-Type': 'application/json',
});
const res = await fetch(url, { method: 'POST', headers, body: value });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
lastPersistedState.set(key, value);
return value;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
const clearStorage = async () => {
try {
persistRefCount++;
const url = getUrl();
const res = await fetch(url, { method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch {
log.error('Failed to reset client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
const registerListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};
return { reduxRememberDriver, clearStorage, registerListeners };
};
// Create redux-remember driver, wrapping idb-keyval
export const idbKeyValDriver: Driver = {
getItem: (key) => {
const buildCustomDriver = (api: {
getItem: (key: string) => Promise<any>;
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
}): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
let persistRefCount = 0;
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
const lastPersistedState = new Map<string, unknown>();
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
log.trace(`Getting client state for key "${key}"`);
return await api.getItem(key);
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key, value) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping setting client state for key "${key}" as value is unchanged`);
return value;
}
log.trace(`Setting client state for key "${key}", ${value}`);
await api.setItem(key, value);
lastPersistedState.set(key, value);
return value;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
const clearStorage = async () => {
try {
return get(key, $idbKeyValStore.get());
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
persistRefCount++;
log.trace('Clearing client state');
await api.clear();
} catch {
log.error('Failed to clear client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
setItem: (key, value) => {
try {
return set(key, value, $idbKeyValStore.get());
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
}
},
};
const registerListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};
return { reduxRememberDriver, clearStorage, registerListeners };
};
export const buildStorageApi = (api?: {
getItem: (key: string) => Promise<any>;
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
}) => {
if (api) {
return buildCustomDriver(api);
} else {
return buildOSSServerBackedDriver();
}
};

View File

@@ -1,73 +0,0 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addImageUploadedFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageUploaded';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
const startAppListening = listenerMiddleware.startListening as AppStartListening;
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
/**
* The RTK listener middleware is a lightweight alternative sagas/observables.
*
* Most side effect logic should live in a listener.
*/
// Image uploaded
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);
// app startup
addAppStartedListener(startAppListening);
addModelsLoadedListener(startAppListening);
addAppConfigReceivedListener(startAppListening);
// Ad-hoc upscale workflwo
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@@ -1,6 +1,6 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';

View File

@@ -1,5 +1,5 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
autoAddBoardIdChanged,

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';

View File

@@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { truncate } from 'es-toolkit/compat';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/state';

View File

@@ -1,5 +1,5 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { parseify } from 'common/util/serialize';
import { size } from 'es-toolkit/compat';
import { $templates } from 'features/nodes/store/nodesSlice';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');

View File

@@ -1,7 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import type { AppStartListening, RootState } from 'app/store/store';
import { omit } from 'es-toolkit/compat';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';

View File

@@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { AppDispatch, AppStartListening, RootState } from 'app/store/store';
import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import {

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';

View File

@@ -1,8 +1,8 @@
import { objectEquals } from '@observ33r/object-equals';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import type { AppStartListening } from 'app/store/store';
import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';

View File

@@ -1,35 +1,46 @@
import type { ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/toolkit';
import type { ThunkDispatch, TypedStartListening, UnknownAction } from '@reduxjs/toolkit';
import { addListener, combineReducers, configureStore, createListenerMiddleware } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { deepClone } from 'common/util/deepClone';
import { keys, mergeWith, omit, pick } from 'es-toolkit/compat';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice';
import {
canvasSessionSlice,
canvasStagingAreaPersistConfig,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
import { refImagesPersistConfig, refImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nodes/store/workflowLibrarySlice';
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import { stylePresetPersistConfig, stylePresetSlice } from 'features/stylePresets/store/stylePresetSlice';
import { configSlice } from 'features/system/store/configSlice';
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
import { changeBoardModalSliceConfig } from 'features/changeBoardModal/store/slice';
import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice';
import { canvasSessionSliceConfig } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice';
import { paramsSliceConfig } from 'features/controlLayers/store/paramsSlice';
import { refImagesSliceConfig } from 'features/controlLayers/store/refImagesSlice';
import { dynamicPromptsSliceConfig } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { gallerySliceConfig } from 'features/gallery/store/gallerySlice';
import { modelManagerSliceConfig } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesSliceConfig } from 'features/nodes/store/nodesSlice';
import { workflowLibrarySliceConfig } from 'features/nodes/store/workflowLibrarySlice';
import { workflowSettingsSliceConfig } from 'features/nodes/store/workflowSettingsSlice';
import { upscaleSliceConfig } from 'features/parameters/store/upscaleSlice';
import { queueSliceConfig } from 'features/queue/store/queueSlice';
import { stylePresetSliceConfig } from 'features/stylePresets/store/stylePresetSlice';
import { configSliceConfig } from 'features/system/store/configSlice';
import { systemSliceConfig } from 'features/system/store/systemSlice';
import { uiSliceConfig } from 'features/ui/store/uiSlice';
import { diff } from 'jsondiffpatch';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
import type { Driver, SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import undoable, { newHistory } from 'redux-undo';
import { serializeError } from 'serialize-error';
@@ -37,123 +48,116 @@ import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest';
import { STORAGE_PREFIX } from './constants';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { addArchivedOrDeletedBoardListener } from './middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener';
import { addImageUploadedFulfilledListener } from './middleware/listenerMiddleware/listeners/imageUploaded';
export const listenerMiddleware = createListenerMiddleware();
const log = logger('system');
const allReducers = {
[api.reducerPath]: api.reducer,
[gallerySlice.name]: gallerySlice.reducer,
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
[systemSlice.name]: systemSlice.reducer,
[configSlice.name]: configSlice.reducer,
[uiSlice.name]: uiSlice.reducer,
[dynamicPromptsSlice.name]: dynamicPromptsSlice.reducer,
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
[queueSlice.name]: queueSlice.reducer,
[canvasSlice.name]: undoable(canvasSlice.reducer, canvasUndoableConfig),
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[upscaleSlice.name]: upscaleSlice.reducer,
[stylePresetSlice.name]: stylePresetSlice.reducer,
[paramsSlice.name]: paramsSlice.reducer,
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
[lorasSlice.name]: lorasSlice.reducer,
[workflowLibrarySlice.name]: workflowLibrarySlice.reducer,
[refImagesSlice.name]: refImagesSlice.reducer,
// When adding a slice, add the config to the SLICE_CONFIGS object below, then add the reducer to ALL_REDUCERS.
const SLICE_CONFIGS = {
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig,
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
[canvasSliceConfig.slice.reducerPath]: canvasSliceConfig,
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig,
[configSliceConfig.slice.reducerPath]: configSliceConfig,
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig,
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig,
[lorasSliceConfig.slice.reducerPath]: lorasSliceConfig,
[modelManagerSliceConfig.slice.reducerPath]: modelManagerSliceConfig,
[nodesSliceConfig.slice.reducerPath]: nodesSliceConfig,
[paramsSliceConfig.slice.reducerPath]: paramsSliceConfig,
[queueSliceConfig.slice.reducerPath]: queueSliceConfig,
[refImagesSliceConfig.slice.reducerPath]: refImagesSliceConfig,
[stylePresetSliceConfig.slice.reducerPath]: stylePresetSliceConfig,
[systemSliceConfig.slice.reducerPath]: systemSliceConfig,
[uiSliceConfig.slice.reducerPath]: uiSliceConfig,
[upscaleSliceConfig.slice.reducerPath]: upscaleSliceConfig,
[workflowLibrarySliceConfig.slice.reducerPath]: workflowLibrarySliceConfig,
[workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig,
};
const rootReducer = combineReducers(allReducers);
// TS makes it really hard to dynamically create this object :/ so it's just hardcoded here.
// Remember to wrap undoable reducers in `undoable()`!
const ALL_REDUCERS = {
[api.reducerPath]: api.reducer,
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig.slice.reducer,
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig.slice.reducer,
// Undoable!
[canvasSliceConfig.slice.reducerPath]: undoable(
canvasSliceConfig.slice.reducer,
canvasSliceConfig.undoableConfig?.reduxUndoOptions
),
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig.slice.reducer,
[configSliceConfig.slice.reducerPath]: configSliceConfig.slice.reducer,
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig.slice.reducer,
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig.slice.reducer,
[lorasSliceConfig.slice.reducerPath]: lorasSliceConfig.slice.reducer,
[modelManagerSliceConfig.slice.reducerPath]: modelManagerSliceConfig.slice.reducer,
// Undoable!
[nodesSliceConfig.slice.reducerPath]: undoable(
nodesSliceConfig.slice.reducer,
nodesSliceConfig.undoableConfig?.reduxUndoOptions
),
[paramsSliceConfig.slice.reducerPath]: paramsSliceConfig.slice.reducer,
[queueSliceConfig.slice.reducerPath]: queueSliceConfig.slice.reducer,
[refImagesSliceConfig.slice.reducerPath]: refImagesSliceConfig.slice.reducer,
[stylePresetSliceConfig.slice.reducerPath]: stylePresetSliceConfig.slice.reducer,
[systemSliceConfig.slice.reducerPath]: systemSliceConfig.slice.reducer,
[uiSliceConfig.slice.reducerPath]: uiSliceConfig.slice.reducer,
[upscaleSliceConfig.slice.reducerPath]: upscaleSliceConfig.slice.reducer,
[workflowLibrarySliceConfig.slice.reducerPath]: workflowLibrarySliceConfig.slice.reducer,
[workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig.slice.reducer,
};
const rootReducer = combineReducers(ALL_REDUCERS);
const rememberedRootReducer = rememberReducer(rootReducer);
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export type PersistConfig<T = any> = {
/**
* The name of the slice.
*/
name: keyof typeof allReducers;
/**
* The initial state of the slice.
*/
initialState: T;
/**
* Migrate the state to the current version during rehydration.
* @param state The rehydrated state.
* @returns A correctly-shaped state.
*/
migrate: (state: unknown) => T;
/**
* Keys to omit from the persisted state.
*/
persistDenylist: (keyof T)[];
};
const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[galleryPersistConfig.name]: galleryPersistConfig,
[nodesPersistConfig.name]: nodesPersistConfig,
[systemPersistConfig.name]: systemPersistConfig,
[uiPersistConfig.name]: uiPersistConfig,
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[canvasPersistConfig.name]: canvasPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
[upscalePersistConfig.name]: upscalePersistConfig,
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
[paramsPersistConfig.name]: paramsPersistConfig,
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
[lorasPersistConfig.name]: lorasPersistConfig,
[workflowLibraryPersistConfig.name]: workflowLibraryPersistConfig,
[refImagesSlice.name]: refImagesPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {
const persistConfig = persistConfigs[key as keyof typeof persistConfigs];
if (!persistConfig) {
const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
if (!sliceConfig?.persistConfig) {
throw new Error(`No persist config for slice "${key}"`);
}
const { getInitialState, persistConfig, undoableConfig } = sliceConfig;
let state;
try {
const { initialState, migrate } = persistConfig;
const parsed = JSON.parse(data);
const initialState = getInitialState();
// strip out old keys
const stripped = pick(deepClone(parsed), keys(initialState));
// run (additive) migrations
const migrated = migrate(stripped);
const stripped = pick(deepClone(data), keys(initialState));
/*
* Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep,
* but that merges arrays by index and partial objects by key. Using an identity function as the customizer results
* in behaviour like defaultsDeep, but doesn't overwrite any values that are not undefined in the migrated state.
*/
const transformed = mergeWith(migrated, initialState, (objVal) => objVal);
const unPersistDenylisted = mergeWith(stripped, initialState, (objVal) => objVal);
// run (additive) migrations
const migrated = persistConfig.migrate(unPersistDenylisted);
log.debug(
{
persistedData: parsed,
rehydratedData: transformed,
diff: diff(parsed, transformed) as JsonObject, // this is always serializable
persistedData: data as JsonObject,
rehydratedData: migrated as JsonObject,
diff: diff(data, migrated) as JsonObject,
},
`Rehydrated slice "${key}"`
);
state = transformed;
state = migrated;
} catch (err) {
log.warn(
{ error: serializeError(err as Error) },
`Error rehydrating slice "${key}", falling back to default initial state`
);
state = persistConfig.initialState;
state = getInitialState();
}
// If the slice is undoable, we need to wrap it in a new history - only nodes and canvas are undoable at the moment.
// TODO(psyche): make this automatic & remove the hard-coding for specific slices.
if (key === nodesSlice.name || key === canvasSlice.name) {
// Undoable slices must be wrapped in a history!
if (undoableConfig) {
return newHistory([], state, []);
} else {
return state;
@@ -161,21 +165,30 @@ const unserialize: UnserializeFunction = (data, key) => {
};
const serialize: SerializeFunction = (data, key) => {
const persistConfig = persistConfigs[key as keyof typeof persistConfigs];
if (!persistConfig) {
const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
if (!sliceConfig?.persistConfig) {
throw new Error(`No persist config for slice "${key}"`);
}
// Heuristic to determine if the slice is undoable - could just hardcode it in the persistConfig
const isUndoable = 'present' in data && 'past' in data && 'future' in data && '_latestUnfiltered' in data;
const result = omit(isUndoable ? data.present : data, persistConfig.persistDenylist);
const result = omit(
sliceConfig.undoableConfig ? data.present : data,
sliceConfig.persistConfig.persistDenylist ?? []
);
return JSON.stringify(result);
};
export const createStore = (uniqueStoreKey?: string, persist = true) =>
const PERSISTED_KEYS = Object.values(SLICE_CONFIGS)
.filter((sliceConfig) => !!sliceConfig.persistConfig)
.map((sliceConfig) => sliceConfig.slice.reducerPath);
export const createStore = (reduxRememberOptions: { driver: Driver; persistThrottle: number }) =>
configureStore({
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
// serializableCheck: false,
// immutableCheck: false,
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
})
@@ -185,19 +198,16 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
// .concat(getDebugLoggerMiddleware())
.prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => {
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
if (persist) {
_enhancers.push(
rememberEnhancer(idbKeyValDriver, keys(persistConfigs), {
persistDebounce: 300,
serialize,
unserialize,
prefix: uniqueStoreKey ? `${STORAGE_PREFIX}${uniqueStoreKey}-` : STORAGE_PREFIX,
errorHandler,
})
);
}
return _enhancers;
const enhancers = getDefaultEnhancers();
return enhancers.prepend(
rememberEnhancer(reduxRememberOptions.driver, PERSISTED_KEYS, {
persistThrottle: reduxRememberOptions.persistThrottle,
serialize,
unserialize,
prefix: '',
errorHandler,
})
);
},
devTools: {
actionSanitizer,
@@ -214,7 +224,48 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
export type AppStore = ReturnType<typeof createStore>;
export type RootState = ReturnType<AppStore['getState']>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
export type AppGetState = ReturnType<typeof createStore>['getState'];
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
const startAppListening = listenerMiddleware.startListening as AppStartListening;
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);
// app startup
addAppStartedListener(startAppListening);
addModelsLoadedListener(startAppListening);
addAppConfigReceivedListener(startAppListening);
// Ad-hoc upscale workflwo
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@@ -0,0 +1,46 @@
import type { Slice } from '@reduxjs/toolkit';
import type { UndoableOptions } from 'redux-undo';
import type { ZodType } from 'zod';
type StateFromSlice<T extends Slice> = T extends Slice<infer U> ? U : never;
export type SliceConfig<T extends Slice> = {
/**
* The redux slice (return of createSlice).
*/
slice: T;
/**
* The zod schema for the slice.
*/
schema: ZodType<StateFromSlice<T>>;
/**
* A function that returns the initial state of the slice.
*/
getInitialState: () => StateFromSlice<T>;
/**
* The optional persist configuration for this slice. If omitted, the slice will not be persisted.
*/
persistConfig?: {
/**
* Migrate the state to the current version during rehydration. This method should throw an error if the migration
* fails.
*
* @param state The rehydrated state.
* @returns A correctly-shaped state.
*/
migrate: (state: unknown) => StateFromSlice<T>;
/**
* Keys to omit from the persisted state.
*/
persistDenylist?: (keyof StateFromSlice<T>)[];
};
/**
* The optional undoable configuration for this slice. If omitted, the slice will not be undoable.
*/
undoableConfig?: {
/**
* The options to be passed into redux-undo.
*/
reduxUndoOptions: UndoableOptions<StateFromSlice<T>>;
};
};

View File

@@ -1,130 +1,299 @@
import type { FilterType } from 'features/controlLayers/store/filters';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { TabName } from 'features/ui/store/uiTypes';
import { zFilterType } from 'features/controlLayers/store/filters';
import { zParameterPrecision, zParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { zTabName } from 'features/ui/store/uiTypes';
import type { PartialDeep } from 'type-fest';
import z from 'zod';
/**
* A disable-able application feature
*/
export type AppFeature =
| 'faceRestore'
| 'upscaling'
| 'lightbox'
| 'modelManager'
| 'githubLink'
| 'discordLink'
| 'bugLink'
| 'aboutModal'
| 'localization'
| 'consoleLogging'
| 'dynamicPrompting'
| 'batches'
| 'syncModels'
| 'multiselect'
| 'pauseQueue'
| 'resumeQueue'
| 'invocationCache'
| 'modelCache'
| 'bulkDownload'
| 'starterModels'
| 'hfToken'
| 'retryQueueItem'
| 'cancelAndClearAll'
| 'chatGPT4oHigh'
| 'modelRelationships';
/**
* A disable-able Stable Diffusion feature
*/
export type SDFeature =
| 'controlNet'
| 'noise'
| 'perlinNoise'
| 'noiseThreshold'
| 'variation'
| 'symmetry'
| 'seamless'
| 'hires'
| 'lora'
| 'embedding'
| 'vae'
| 'hrf';
const zAppFeature = z.enum([
'faceRestore',
'upscaling',
'lightbox',
'modelManager',
'githubLink',
'discordLink',
'bugLink',
'aboutModal',
'localization',
'consoleLogging',
'dynamicPrompting',
'batches',
'syncModels',
'multiselect',
'pauseQueue',
'resumeQueue',
'invocationCache',
'modelCache',
'bulkDownload',
'starterModels',
'hfToken',
'retryQueueItem',
'cancelAndClearAll',
'chatGPT4oHigh',
'modelRelationships',
]);
export type AppFeature = z.infer<typeof zAppFeature>;
export type NumericalParameterConfig = {
initial: number;
sliderMin: number;
sliderMax: number;
numberInputMin: number;
numberInputMax: number;
fineStep: number;
coarseStep: number;
};
const zSDFeature = z.enum([
'controlNet',
'noise',
'perlinNoise',
'noiseThreshold',
'variation',
'symmetry',
'seamless',
'hires',
'lora',
'embedding',
'vae',
'hrf',
]);
export type SDFeature = z.infer<typeof zSDFeature>;
const zNumericalParameterConfig = z.object({
initial: z.number().default(512),
sliderMin: z.number().default(64),
sliderMax: z.number().default(1536),
numberInputMin: z.number().default(64),
numberInputMax: z.number().default(4096),
fineStep: z.number().default(8),
coarseStep: z.number().default(64),
});
/**
* Configuration options for the InvokeAI UI.
* Distinct from system settings which may be changed inside the app.
*/
export type AppConfig = {
export const zAppConfig = z.object({
/**
* Whether or not we should update image urls when image loading errors
*/
shouldUpdateImagesOnConnect: boolean;
shouldFetchMetadataFromApi: boolean;
shouldUpdateImagesOnConnect: z.boolean(),
shouldFetchMetadataFromApi: z.boolean(),
/**
* Sets a size limit for outputs on the upscaling tab. This is a maximum dimension, so the actual max number of pixels
* will be the square of this value.
*/
maxUpscaleDimension?: number;
allowPrivateBoards: boolean;
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
allowPublishWorkflows: boolean;
allowPromptExpansion: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
nodesAllowlist: string[] | undefined;
nodesDenylist: string[] | undefined;
metadataFetchDebounce?: number;
workflowFetchDebounce?: number;
isLocal?: boolean;
shouldShowCredits: boolean;
sd: {
defaultModel?: string;
disabledControlNetModels: string[];
disabledControlNetProcessors: FilterType[];
maxUpscaleDimension: z.number().optional(),
allowPrivateBoards: z.boolean(),
allowPrivateStylePresets: z.boolean(),
allowClientSideUpload: z.boolean(),
allowPublishWorkflows: z.boolean(),
allowPromptExpansion: z.boolean(),
disabledTabs: z.array(zTabName),
disabledFeatures: z.array(zAppFeature),
disabledSDFeatures: z.array(zSDFeature),
nodesAllowlist: z.array(z.string()).optional(),
nodesDenylist: z.array(z.string()).optional(),
metadataFetchDebounce: z.number().int().optional(),
workflowFetchDebounce: z.number().int().optional(),
isLocal: z.boolean().optional(),
shouldShowCredits: z.boolean().optional(),
sd: z.object({
defaultModel: z.string().optional(),
disabledControlNetModels: z.array(z.string()),
disabledControlNetProcessors: z.array(zFilterType),
// Core parameters
iterations: NumericalParameterConfig;
width: NumericalParameterConfig; // initial value comes from model
height: NumericalParameterConfig; // initial value comes from model
steps: NumericalParameterConfig;
guidance: NumericalParameterConfig;
cfgRescaleMultiplier: NumericalParameterConfig;
img2imgStrength: NumericalParameterConfig;
scheduler?: ParameterScheduler;
vaePrecision?: ParameterPrecision;
iterations: zNumericalParameterConfig,
width: zNumericalParameterConfig,
height: zNumericalParameterConfig,
steps: zNumericalParameterConfig,
guidance: zNumericalParameterConfig,
cfgRescaleMultiplier: zNumericalParameterConfig,
img2imgStrength: zNumericalParameterConfig,
scheduler: zParameterScheduler.optional(),
vaePrecision: zParameterPrecision.optional(),
// Canvas
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model
scaledBoundingBoxHeight: NumericalParameterConfig; // initial value comes from model
scaledBoundingBoxWidth: NumericalParameterConfig; // initial value comes from model
canvasCoherenceStrength: NumericalParameterConfig;
canvasCoherenceEdgeSize: NumericalParameterConfig;
infillTileSize: NumericalParameterConfig;
infillPatchmatchDownscaleSize: NumericalParameterConfig;
boundingBoxHeight: zNumericalParameterConfig,
boundingBoxWidth: zNumericalParameterConfig,
scaledBoundingBoxHeight: zNumericalParameterConfig,
scaledBoundingBoxWidth: zNumericalParameterConfig,
canvasCoherenceStrength: zNumericalParameterConfig,
canvasCoherenceEdgeSize: zNumericalParameterConfig,
infillTileSize: zNumericalParameterConfig,
infillPatchmatchDownscaleSize: zNumericalParameterConfig,
// Misc advanced
clipSkip: NumericalParameterConfig; // slider and input max are ignored for this, because the values depend on the model
maskBlur: NumericalParameterConfig;
hrfStrength: NumericalParameterConfig;
dynamicPrompts: {
maxPrompts: NumericalParameterConfig;
};
ca: {
weight: NumericalParameterConfig;
};
};
flux: {
guidance: NumericalParameterConfig;
};
};
clipSkip: zNumericalParameterConfig, // slider and input max are ignored for this, because the values depend on the model
maskBlur: zNumericalParameterConfig,
hrfStrength: zNumericalParameterConfig,
dynamicPrompts: z.object({
maxPrompts: zNumericalParameterConfig,
}),
ca: z.object({
weight: zNumericalParameterConfig,
}),
}),
flux: z.object({
guidance: zNumericalParameterConfig,
}),
});
export type AppConfig = z.infer<typeof zAppConfig>;
export type PartialAppConfig = PartialDeep<AppConfig>;
export const getDefaultAppConfig = (): AppConfig => ({
isLocal: true,
shouldUpdateImagesOnConnect: false,
shouldFetchMetadataFromApi: false,
allowPrivateBoards: false,
allowPrivateStylePresets: false,
allowClientSideUpload: false,
allowPublishWorkflows: false,
allowPromptExpansion: false,
shouldShowCredits: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'] satisfies AppFeature[],
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'] satisfies SDFeature[],
sd: {
disabledControlNetModels: [],
disabledControlNetProcessors: [],
iterations: {
initial: 1,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 1,
},
width: zNumericalParameterConfig.parse({}), // initial value comes from model
height: zNumericalParameterConfig.parse({}), // initial value comes from model
boundingBoxWidth: zNumericalParameterConfig.parse({}), // initial value comes from model
boundingBoxHeight: zNumericalParameterConfig.parse({}), // initial value comes from model
scaledBoundingBoxWidth: zNumericalParameterConfig.parse({}), // initial value comes from model
scaledBoundingBoxHeight: zNumericalParameterConfig.parse({}), // initial value comes from model
scheduler: 'dpmpp_3m_k' as const,
vaePrecision: 'fp32' as const,
steps: {
initial: 30,
sliderMin: 1,
sliderMax: 100,
numberInputMin: 1,
numberInputMax: 500,
fineStep: 1,
coarseStep: 1,
},
guidance: {
initial: 7,
sliderMin: 1,
sliderMax: 20,
numberInputMin: 1,
numberInputMax: 200,
fineStep: 0.1,
coarseStep: 0.5,
},
img2imgStrength: {
initial: 0.7,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceStrength: {
initial: 0.3,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
hrfStrength: {
initial: 0.45,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceEdgeSize: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 1024,
fineStep: 8,
coarseStep: 16,
},
cfgRescaleMultiplier: {
initial: 0,
sliderMin: 0,
sliderMax: 0.99,
numberInputMin: 0,
numberInputMax: 0.99,
fineStep: 0.05,
coarseStep: 0.1,
},
clipSkip: {
initial: 0,
sliderMin: 0,
sliderMax: 12, // determined by model selection, unused in practice
numberInputMin: 0,
numberInputMax: 12, // determined by model selection, unused in practice
fineStep: 1,
coarseStep: 1,
},
infillPatchmatchDownscaleSize: {
initial: 1,
sliderMin: 1,
sliderMax: 10,
numberInputMin: 1,
numberInputMax: 10,
fineStep: 1,
coarseStep: 1,
},
infillTileSize: {
initial: 32,
sliderMin: 16,
sliderMax: 64,
numberInputMin: 16,
numberInputMax: 256,
fineStep: 1,
coarseStep: 1,
},
maskBlur: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 512,
fineStep: 1,
coarseStep: 1,
},
ca: {
weight: {
initial: 1,
sliderMin: 0,
sliderMax: 2,
numberInputMin: -1,
numberInputMax: 2,
fineStep: 0.01,
coarseStep: 0.05,
},
},
dynamicPrompts: {
maxPrompts: {
initial: 100,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 10,
},
},
},
flux: {
guidance: {
initial: 4,
sliderMin: 2,
sliderMax: 6,
numberInputMin: 1,
numberInputMax: 20,
fineStep: 0.1,
coarseStep: 0.5,
},
},
});

View File

@@ -1,6 +1,8 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { allEntitiesDeleted } from 'features/controlLayers/store/canvasSlice';
import { canvasReset } from 'features/controlLayers/store/actions';
import { inpaintMaskAdded } from 'features/controlLayers/store/canvasSlice';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@@ -11,7 +13,9 @@ export const SessionMenuItems = memo(() => {
const dispatch = useAppDispatch();
const resetCanvasLayers = useCallback(() => {
dispatch(allEntitiesDeleted());
dispatch(canvasReset());
dispatch(inpaintMaskAdded({ isSelected: true, isBookmarked: true }));
$canvasManager.get()?.stage.fitBboxToStage();
}, [dispatch]);
const resetGenerationSettings = useCallback(() => {
dispatch(paramsReset());

View File

@@ -1,11 +0,0 @@
import { clearIdbKeyValStore } from 'app/store/enhancers/reduxRemember/driver';
import { useCallback } from 'react';
export const useClearStorage = () => {
const clearStorage = useCallback(() => {
clearIdbKeyValStore();
localStorage.clear();
}, []);
return clearStorage;
};

View File

@@ -139,4 +139,13 @@ export const useGlobalHotkeys = () => {
},
dependencies: [getState, deleteImageModalApi],
});
useRegisteredHotkeys({
id: 'toggleViewer',
category: 'viewer',
callback: () => {
navigationApi.toggleViewerPanel();
},
dependencies: [],
});
};

View File

@@ -1,6 +0,0 @@
import type { ChangeBoardModalState } from './types';
export const initialState: ChangeBoardModalState = {
isModalOpen: false,
image_names: [],
};

View File

@@ -1,12 +1,20 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import z from 'zod';
import { initialState } from './initialState';
const zChangeBoardModalState = z.object({
isModalOpen: z.boolean().default(false),
image_names: z.array(z.string()).default(() => []),
});
type ChangeBoardModalState = z.infer<typeof zChangeBoardModalState>;
export const changeBoardModalSlice = createSlice({
const getInitialState = (): ChangeBoardModalState => zChangeBoardModalState.parse({});
const slice = createSlice({
name: 'changeBoardModal',
initialState,
initialState: getInitialState(),
reducers: {
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isModalOpen = action.payload;
@@ -21,6 +29,12 @@ export const changeBoardModalSlice = createSlice({
},
});
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } = changeBoardModalSlice.actions;
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } = slice.actions;
export const selectChangeBoardModalSlice = (state: RootState) => state.changeBoardModal;
export const changeBoardModalSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zChangeBoardModalState,
getInitialState,
};

View File

@@ -1,4 +0,0 @@
export type ChangeBoardModalState = {
isModalOpen: boolean;
image_names: string[];
};

View File

@@ -165,9 +165,9 @@ export const CanvasEntityGroupList = memo(({ isSelected, type, children, entityI
<Spacer />
</Flex>
{type === 'raster_layer' && <RasterLayerExportPSDButton />}
<CanvasEntityMergeVisibleButton type={type} />
<CanvasEntityTypeIsHiddenToggle type={type} />
{type === 'raster_layer' && <RasterLayerExportPSDButton />}
<CanvasEntityAddOfTypeButton type={type} />
</Flex>
<Collapse in={collapse.isTrue} style={fixTooltipCloseOnScrollStyles}>

View File

@@ -42,7 +42,7 @@ const DEFAULT_CONFIG: CanvasStageModuleConfig = {
SCALE_FACTOR: 0.999,
FIT_LAYERS_TO_STAGE_PADDING_PX: 48,
SCALE_SNAP_POINTS: [0.25, 0.5, 0.75, 1, 1.5, 2, 3, 4, 5],
SCALE_SNAP_TOLERANCE: 0.05,
SCALE_SNAP_TOLERANCE: 0.02,
};
export class CanvasStageModule extends CanvasModuleBase {
@@ -366,11 +366,22 @@ export class CanvasStageModule extends CanvasModuleBase {
if (deltaT > 300) {
dynamicScaleFactor = this.config.SCALE_FACTOR + (1 - this.config.SCALE_FACTOR) / 2;
} else if (deltaT < 300) {
dynamicScaleFactor = this.config.SCALE_FACTOR + (1 - this.config.SCALE_FACTOR) * (deltaT / 200);
// Ensure dynamic scale factor stays below 1 to maintain zoom-out direction - if it goes over, we could end up
// zooming in the wrong direction with small scroll amounts
const maxScaleFactor = 0.9999;
dynamicScaleFactor = Math.min(
this.config.SCALE_FACTOR + (1 - this.config.SCALE_FACTOR) * (deltaT / 200),
maxScaleFactor
);
}
// Update the intended scale based on the last intended scale, creating a continuous zoom feel
const newIntendedScale = this._intendedScale * dynamicScaleFactor ** scrollAmount;
// Handle the sign explicitly to prevent direction reversal with small scroll amounts
const scaleFactor =
scrollAmount > 0
? dynamicScaleFactor ** Math.abs(scrollAmount)
: (1 / dynamicScaleFactor) ** Math.abs(scrollAmount);
const newIntendedScale = this._intendedScale * scaleFactor;
this._intendedScale = this.constrainScale(newIntendedScale);
// Pass control to the snapping logic
@@ -397,6 +408,9 @@ export class CanvasStageModule extends CanvasModuleBase {
// User has scrolled far enough to break the snap
this._activeSnapPoint = null;
this._applyScale(this._intendedScale, center);
} else {
// Reset intended scale to prevent drift while snapped
this._intendedScale = this._activeSnapPoint;
}
// Else, do nothing - we remain snapped at the current scale, creating a "dead zone"
return;

View File

@@ -1,7 +1,7 @@
import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Selector } from '@reduxjs/toolkit';
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
import type { AppStore, RootState } from 'app/store/store';
import { addAppListener } from 'app/store/store';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';

View File

@@ -1,6 +1,7 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { zRgbaColor } from 'features/controlLayers/store/types';
import { z } from 'zod';
@@ -11,32 +12,32 @@ const zCanvasSettingsState = z.object({
/**
* Whether to show HUD (Heads-Up Display) on the canvas.
*/
showHUD: z.boolean().default(true),
showHUD: z.boolean(),
/**
* Whether to clip lines and shapes to the generation bounding box. If disabled, lines and shapes will be clipped to
* the canvas bounds.
*/
clipToBbox: z.boolean().default(false),
clipToBbox: z.boolean(),
/**
* Whether to show a dynamic grid on the canvas. If disabled, a checkerboard pattern will be shown instead.
*/
dynamicGrid: z.boolean().default(false),
dynamicGrid: z.boolean(),
/**
* Whether to invert the scroll direction when adjusting the brush or eraser width with the scroll wheel.
*/
invertScrollForToolWidth: z.boolean().default(false),
invertScrollForToolWidth: z.boolean(),
/**
* The width of the brush tool.
*/
brushWidth: z.int().gt(0).default(50),
brushWidth: z.int().gt(0),
/**
* The width of the eraser tool.
*/
eraserWidth: z.int().gt(0).default(50),
eraserWidth: z.int().gt(0),
/**
* The color to use when drawing lines or filling shapes.
*/
color: zRgbaColor.default({ r: 31, g: 160, b: 224, a: 1 }), // invokeBlue.500
color: zRgbaColor,
/**
* Whether to composite inpainted/outpainted regions back onto the source image when saving canvas generations.
*
@@ -44,57 +45,77 @@ const zCanvasSettingsState = z.object({
*
* When `sendToCanvas` is disabled, this setting is ignored, masked regions will always be composited.
*/
outputOnlyMaskedRegions: z.boolean().default(true),
outputOnlyMaskedRegions: z.boolean(),
/**
* Whether to automatically process the operations like filtering and auto-masking.
*/
autoProcess: z.boolean().default(true),
autoProcess: z.boolean(),
/**
* The snap-to-grid setting for the canvas.
*/
snapToGrid: z.boolean().default(true),
snapToGrid: z.boolean(),
/**
* Whether to show progress on the canvas when generating images.
*/
showProgressOnCanvas: z.boolean().default(true),
showProgressOnCanvas: z.boolean(),
/**
* Whether to show the bounding box overlay on the canvas.
*/
bboxOverlay: z.boolean().default(false),
bboxOverlay: z.boolean(),
/**
* Whether to preserve the masked region instead of inpainting it.
*/
preserveMask: z.boolean().default(false),
preserveMask: z.boolean(),
/**
* Whether to show only raster layers while staging.
*/
isolatedStagingPreview: z.boolean().default(true),
isolatedStagingPreview: z.boolean(),
/**
* Whether to show only the selected layer while filtering, transforming, or doing other operations.
*/
isolatedLayerPreview: z.boolean().default(true),
isolatedLayerPreview: z.boolean(),
/**
* Whether to use pressure sensitivity for the brush and eraser tool when a pen device is used.
*/
pressureSensitivity: z.boolean().default(true),
pressureSensitivity: z.boolean(),
/**
* Whether to show the rule of thirds composition guide overlay on the canvas.
*/
ruleOfThirds: z.boolean().default(false),
ruleOfThirds: z.boolean(),
/**
* Whether to save all staging images to the gallery instead of keeping them as intermediate images.
*/
saveAllImagesToGallery: z.boolean().default(false),
saveAllImagesToGallery: z.boolean(),
/**
* The auto-switch mode for the canvas staging area.
*/
stagingAreaAutoSwitch: zAutoSwitchMode.default('switch_on_start'),
stagingAreaAutoSwitch: zAutoSwitchMode,
});
type CanvasSettingsState = z.infer<typeof zCanvasSettingsState>;
const getInitialState = () => zCanvasSettingsState.parse({});
const getInitialState = (): CanvasSettingsState => ({
showHUD: true,
clipToBbox: false,
dynamicGrid: false,
invertScrollForToolWidth: false,
brushWidth: 50,
eraserWidth: 50,
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
outputOnlyMaskedRegions: true,
autoProcess: true,
snapToGrid: true,
showProgressOnCanvas: true,
bboxOverlay: false,
preserveMask: false,
isolatedStagingPreview: true,
isolatedLayerPreview: true,
pressureSensitivity: true,
ruleOfThirds: false,
saveAllImagesToGallery: false,
stagingAreaAutoSwitch: 'switch_on_start',
});
export const canvasSettingsSlice = createSlice({
const slice = createSlice({
name: 'canvasSettings',
initialState: getInitialState(),
reducers: {
@@ -184,18 +205,15 @@ export const {
settingsRuleOfThirdsToggled,
settingsSaveAllImagesToGalleryToggled,
settingsStagingAreaAutoSwitchChanged,
} = canvasSettingsSlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const canvasSettingsPersistConfig: PersistConfig<CanvasSettingsState> = {
name: canvasSettingsSlice.name,
initialState: getInitialState(),
migrate,
persistDenylist: [],
export const canvasSettingsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zCanvasSettingsState,
getInitialState,
persistConfig: {
migrate: (state) => zCanvasSettingsState.parse(state),
},
};
export const selectCanvasSettingsSlice = (s: RootState) => s.canvasSettings;

View File

@@ -1,6 +1,6 @@
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
@@ -80,6 +80,7 @@ import {
isFLUXReduxConfig,
isImagenAspectRatioID,
isIPAdapterConfig,
zCanvasState,
} from './types';
import {
converters,
@@ -95,7 +96,7 @@ import {
initialT2IAdapter,
} from './util';
export const canvasSlice = createSlice({
const slice = createSlice({
name: 'canvas',
initialState: getInitialCanvasState(),
reducers: {
@@ -1618,7 +1619,6 @@ export const {
entityArrangedToBack,
entityOpacityChanged,
entitiesReordered,
allEntitiesDeleted,
allEntitiesOfTypeIsHiddenToggled,
allNonRasterLayersIsHiddenToggled,
// bbox
@@ -1676,19 +1676,7 @@ export const {
inpaintMaskDenoiseLimitChanged,
inpaintMaskDenoiseLimitDeleted,
// inpaintMaskRecalled,
} = canvasSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const canvasPersistConfig: PersistConfig<CanvasState> = {
name: canvasSlice.name,
initialState: getInitialCanvasState(),
migrate,
persistDenylist: [],
};
} = slice.actions;
const syncScaledSize = (state: CanvasState) => {
if (API_BASE_MODELS.includes(state.bbox.modelBase)) {
@@ -1711,14 +1699,14 @@ const syncScaledSize = (state: CanvasState) => {
let filter = true;
export const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
limit: 64,
undoType: canvasUndo.type,
redoType: canvasRedo.type,
clearHistoryType: canvasClearHistory.type,
filter: (action, _state, _history) => {
// Ignore all actions from other slices
if (!action.type.startsWith(canvasSlice.name)) {
if (!action.type.startsWith(slice.name)) {
return false;
}
// Throttle rapid actions of the same type
@@ -1729,6 +1717,18 @@ export const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> =
// debug: import.meta.env.MODE === 'development',
};
export const canvasSliceConfig: SliceConfig<typeof slice> = {
slice,
getInitialState: getInitialCanvasState,
schema: zCanvasState,
persistConfig: {
migrate: (state) => zCanvasState.parse(state),
},
undoableConfig: {
reduxUndoOptions: canvasUndoableConfig,
},
};
const doNotGroupMatcher = isAnyOf(entityBrushLineAdded, entityEraserLineAdded, entityRectAdded);
// Store rapid actions of the same type at most once every x time.

View File

@@ -1,27 +1,29 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { useMemo } from 'react';
import { queueApi } from 'services/api/endpoints/queue';
import { assert } from 'tsafe';
import z from 'zod';
type CanvasStagingAreaState = {
_version: 1;
canvasSessionId: string;
canvasDiscardedQueueItems: number[];
};
const zCanvasStagingAreaState = z.object({
_version: z.literal(1),
canvasSessionId: z.string(),
canvasDiscardedQueueItems: z.array(z.number().int()),
});
type CanvasStagingAreaState = z.infer<typeof zCanvasStagingAreaState>;
const INITIAL_STATE: CanvasStagingAreaState = {
const getInitialState = (): CanvasStagingAreaState => ({
_version: 1,
canvasSessionId: getPrefixedId('canvas'),
canvasDiscardedQueueItems: [],
};
});
const getInitialState = (): CanvasStagingAreaState => deepClone(INITIAL_STATE);
export const canvasSessionSlice = createSlice({
const slice = createSlice({
name: 'canvasSession',
initialState: getInitialState(),
reducers: {
@@ -48,26 +50,26 @@ export const canvasSessionSlice = createSlice({
},
});
export const { canvasSessionReset, canvasQueueItemDiscarded } = canvasSessionSlice.actions;
export const { canvasSessionReset, canvasQueueItemDiscarded } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
export const canvasSessionSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zCanvasStagingAreaState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
return state;
return zCanvasStagingAreaState.parse(state);
},
},
};
export const canvasStagingAreaPersistConfig: PersistConfig<CanvasStagingAreaState> = {
name: canvasSessionSlice.name,
initialState: getInitialState(),
migrate,
persistDenylist: [],
};
export const selectCanvasSessionSlice = (s: RootState) => s[canvasSessionSlice.name];
export const selectCanvasSessionSlice = (s: RootState) => s[slice.name];
export const selectCanvasSessionId = createSelector(selectCanvasSessionSlice, ({ canvasSessionId }) => canvasSessionId);
const selectDiscardedItems = createSelector(

View File

@@ -166,7 +166,7 @@ const _zFilterConfig = z.discriminatedUnion('type', [
]);
export type FilterConfig = z.infer<typeof _zFilterConfig>;
const zFilterType = z.enum([
export const zFilterType = z.enum([
'adjust_image',
'canny_edge_detection',
'color_map',

View File

@@ -1,30 +1,32 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { LoRA } from 'features/controlLayers/store/types';
import { type LoRA, zLoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { LoRAModelConfig } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import z from 'zod';
type LoRAsState = {
loras: LoRA[];
};
const zLoRAsState = z.object({
loras: z.array(zLoRA),
});
type LoRAsState = z.infer<typeof zLoRAsState>;
const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};
const initialState: LoRAsState = {
const getInitialState = (): LoRAsState => ({
loras: [],
};
});
const selectLoRA = (state: LoRAsState, id: string) => state.loras.find((lora) => lora.id === id);
export const lorasSlice = createSlice({
const slice = createSlice({
name: 'loras',
initialState,
initialState: getInitialState(),
reducers: {
loraAdded: {
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
@@ -66,24 +68,21 @@ export const lorasSlice = createSlice({
extraReducers(builder) {
builder.addCase(paramsReset, () => {
// When a new session is requested, clear all LoRAs
return deepClone(initialState);
return getInitialState();
});
},
});
export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } =
lorasSlice.actions;
slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const lorasPersistConfig: PersistConfig<LoRAsState> = {
name: lorasSlice.name,
initialState,
migrate,
persistDenylist: [],
export const lorasSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zLoRAsState,
getInitialState,
persistConfig: {
migrate: (state) => zLoRAsState.parse(state),
},
};
export const selectLoRAsSlice = (state: RootState) => state.loras;

View File

@@ -1,6 +1,7 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
import { clamp } from 'es-toolkit/compat';
@@ -15,6 +16,7 @@ import {
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isImagenAspectRatioID,
zParamsState,
} from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
@@ -40,7 +42,7 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
export const paramsSlice = createSlice({
const slice = createSlice({
name: 'params',
initialState: getInitialParamsState(),
reducers: {
@@ -92,7 +94,12 @@ export const paramsSlice = createSlice({
state,
action: PayloadAction<{ model: ParameterModel | null; previousModel?: ParameterModel | null }>
) => {
const { model, previousModel } = action.payload;
const { previousModel } = action.payload;
const result = zParamsState.shape.model.safeParse(action.payload.model);
if (!result.success) {
return;
}
const model = result.data;
state.model = model;
// If the model base changes (e.g. SD1.5 -> SDXL), we need to change a few things
@@ -111,25 +118,53 @@ export const paramsSlice = createSlice({
},
vaeSelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
// null is a valid VAE!
state.vae = action.payload;
const result = zParamsState.shape.vae.safeParse(action.payload);
if (!result.success) {
return;
}
state.vae = result.data;
},
fluxVAESelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
state.fluxVAE = action.payload;
const result = zParamsState.shape.fluxVAE.safeParse(action.payload);
if (!result.success) {
return;
}
state.fluxVAE = result.data;
},
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
state.t5EncoderModel = action.payload;
const result = zParamsState.shape.t5EncoderModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.t5EncoderModel = result.data;
},
controlLoRAModelSelected: (state, action: PayloadAction<ParameterControlLoRAModel | null>) => {
state.controlLora = action.payload;
const result = zParamsState.shape.controlLora.safeParse(action.payload);
if (!result.success) {
return;
}
state.controlLora = result.data;
},
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
state.clipEmbedModel = action.payload;
const result = zParamsState.shape.clipEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipEmbedModel = result.data;
},
clipLEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPLEmbedModel | null>) => {
state.clipLEmbedModel = action.payload;
const result = zParamsState.shape.clipLEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipLEmbedModel = result.data;
},
clipGEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPGEmbedModel | null>) => {
state.clipGEmbedModel = action.payload;
const result = zParamsState.shape.clipGEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipGEmbedModel = result.data;
},
vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => {
state.vaePrecision = action.payload;
@@ -156,7 +191,11 @@ export const paramsSlice = createSlice({
state.shouldConcatPrompts = action.payload;
},
refinerModelChanged: (state, action: PayloadAction<ParameterSDXLRefinerModel | null>) => {
state.refinerModel = action.payload;
const result = zParamsState.shape.refinerModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.refinerModel = result.data;
},
setRefinerSteps: (state, action: PayloadAction<number>) => {
state.refinerSteps = action.payload;
@@ -397,18 +436,15 @@ export const {
syncedToOptimalDimension,
paramsReset,
} = paramsSlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const paramsPersistConfig: PersistConfig<ParamsState> = {
name: paramsSlice.name,
initialState: getInitialParamsState(),
migrate,
persistDenylist: [],
export const paramsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zParamsState,
getInitialState: getInitialParamsState,
persistConfig: {
migrate: (state) => zParamsState.parse(state),
},
};
export const selectParamsSlice = (state: RootState) => state.params;

View File

@@ -2,7 +2,8 @@ import { objectEquals } from '@observ33r/object-equals';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { clamp } from 'es-toolkit/compat';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types';
@@ -18,7 +19,7 @@ import { assert } from 'tsafe';
import type { PartialDeep } from 'type-fest';
import type { CLIPVisionModelV2, IPMethodV2, RefImageState } from './types';
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig } from './types';
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig, zRefImagesState } from './types';
import {
getReferenceImageState,
imageDTOToImageWithDims,
@@ -36,7 +37,7 @@ type PayloadActionWithId<T = void> = T extends void
} & T
>;
export const refImagesSlice = createSlice({
const slice = createSlice({
name: 'refImages',
initialState: getInitialRefImagesState(),
reducers: {
@@ -263,18 +264,16 @@ export const {
refImageFLUXReduxImageInfluenceChanged,
refImageIsEnabledToggled,
refImagesRecalled,
} = refImagesSlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const refImagesPersistConfig: PersistConfig<RefImagesState> = {
name: refImagesSlice.name,
initialState: getInitialRefImagesState(),
migrate,
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
export const refImagesSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zRefImagesState,
getInitialState: getInitialRefImagesState,
persistConfig: {
migrate: (state) => zRefImagesState.parse(state),
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
},
};
export const selectRefImagesSlice = (state: RootState) => state.refImages;

View File

@@ -1,9 +1,7 @@
import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
import type { ProgressImage } from 'features/nodes/types/common';
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import {
zParameterCanvasCoherenceMode,
zParameterCFGRescaleMultiplier,
@@ -29,33 +27,17 @@ import {
zParameterT5EncoderModel,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { getImageDTOSafe } from 'services/api/endpoints/images';
import type { JsonObject } from 'type-fest';
import { z } from 'zod';
const zId = z.string().min(1);
const zName = z.string().min(1).nullable();
const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => {
try {
await fetchModelConfigByIdentifier(modelIdentifier);
return true;
} catch {
return false;
}
export const zImageWithDims = z.object({
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
});
const zImageWithDims = z
.object({
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
})
.refine(async (v) => {
const { image_name } = v;
const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true });
return imageDTO !== null;
});
export type ImageWithDims = z.infer<typeof zImageWithDims>;
const zImageWithDimsDataURL = z.object({
@@ -253,7 +235,7 @@ export type CanvasObjectState = z.infer<typeof zCanvasObjectState>;
const zIPAdapterConfig = z.object({
type: z.literal('ip_adapter'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
method: zIPMethodV2,
@@ -268,7 +250,7 @@ export type FLUXReduxImageInfluence = z.infer<typeof zFLUXReduxImageInfluence>;
const zFLUXReduxConfig = z.object({
type: z.literal('flux_redux'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
});
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
@@ -281,14 +263,14 @@ const zChatGPT4oReferenceImageConfig = z.object({
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else
* there will be no way to switch between ref image types.
*/
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
});
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
const zFluxKontextReferenceImageConfig = z.object({
type: z.literal('flux_kontext_reference_image'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
});
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;
@@ -360,7 +342,7 @@ export type CanvasInpaintMaskState = z.infer<typeof zCanvasInpaintMaskState>;
const zControlNetConfig = z.object({
type: z.literal('controlnet'),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
controlMode: zControlModeV2,
@@ -369,7 +351,7 @@ export type ControlNetConfig = z.infer<typeof zControlNetConfig>;
const zT2IAdapterConfig = z.object({
type: z.literal('t2i_adapter'),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
});
@@ -378,7 +360,7 @@ export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
const zControlLoRAConfig = z.object({
type: z.literal('control_lora'),
weight: z.number().gte(-1).lte(2),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
});
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
@@ -424,12 +406,13 @@ export const zCanvasEntityIdentifer = z.object({
});
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
export type LoRA = {
id: string;
isEnabled: boolean;
model: ParameterLoRAModel;
weight: number;
};
export const zLoRA = z.object({
id: z.string(),
isEnabled: z.boolean(),
model: zModelIdentifierField,
weight: z.number().gte(-1).lte(2),
});
export type LoRA = z.infer<typeof zLoRA>;
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
@@ -522,62 +505,108 @@ const zDimensionsState = z.object({
aspectRatio: zAspectRatioConfig,
});
const zParamsState = z.object({
maskBlur: z.number().default(16),
maskBlurMethod: zParameterMaskBlurMethod.default('box'),
canvasCoherenceMode: zParameterCanvasCoherenceMode.default('Gaussian Blur'),
canvasCoherenceMinDenoise: zParameterStrength.default(0),
canvasCoherenceEdgeSize: z.number().default(16),
infillMethod: z.string().default('lama'),
infillTileSize: z.number().default(32),
infillPatchmatchDownscaleSize: z.number().default(1),
infillColorValue: zRgbaColor.default({ r: 0, g: 0, b: 0, a: 1 }),
cfgScale: zParameterCFGScale.default(7.5),
cfgRescaleMultiplier: zParameterCFGRescaleMultiplier.default(0),
guidance: zParameterGuidance.default(4),
img2imgStrength: zParameterStrength.default(0.75),
optimizedDenoisingEnabled: z.boolean().default(true),
iterations: z.number().default(1),
scheduler: zParameterScheduler.default('dpmpp_3m_k'),
upscaleScheduler: zParameterScheduler.default('kdpm_2'),
upscaleCfgScale: zParameterCFGScale.default(2),
seed: zParameterSeed.default(0),
shouldRandomizeSeed: z.boolean().default(true),
steps: zParameterSteps.default(30),
model: zParameterModel.nullable().default(null),
vae: zParameterVAEModel.nullable().default(null),
vaePrecision: zParameterPrecision.default('fp32'),
fluxVAE: zParameterVAEModel.nullable().default(null),
seamlessXAxis: z.boolean().default(false),
seamlessYAxis: z.boolean().default(false),
clipSkip: z.number().default(0),
shouldUseCpuNoise: z.boolean().default(true),
positivePrompt: zParameterPositivePrompt.default(''),
// Negative prompt may be disabled, in which case it will be null
negativePrompt: zParameterNegativePrompt.default(null),
positivePrompt2: zParameterPositiveStylePromptSDXL.default(''),
negativePrompt2: zParameterNegativeStylePromptSDXL.default(''),
shouldConcatPrompts: z.boolean().default(true),
refinerModel: zParameterSDXLRefinerModel.nullable().default(null),
refinerSteps: z.number().default(20),
refinerCFGScale: z.number().default(7.5),
refinerScheduler: zParameterScheduler.default('euler'),
refinerPositiveAestheticScore: z.number().default(6),
refinerNegativeAestheticScore: z.number().default(2.5),
refinerStart: z.number().default(0.8),
t5EncoderModel: zParameterT5EncoderModel.nullable().default(null),
clipEmbedModel: zParameterCLIPEmbedModel.nullable().default(null),
clipLEmbedModel: zParameterCLIPLEmbedModel.nullable().default(null),
clipGEmbedModel: zParameterCLIPGEmbedModel.nullable().default(null),
controlLora: zParameterControlLoRAModel.nullable().default(null),
dimensions: zDimensionsState.default({
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
}),
export const zParamsState = z.object({
maskBlur: z.number(),
maskBlurMethod: zParameterMaskBlurMethod,
canvasCoherenceMode: zParameterCanvasCoherenceMode,
canvasCoherenceMinDenoise: zParameterStrength,
canvasCoherenceEdgeSize: z.number(),
infillMethod: z.string(),
infillTileSize: z.number(),
infillPatchmatchDownscaleSize: z.number(),
infillColorValue: zRgbaColor,
cfgScale: zParameterCFGScale,
cfgRescaleMultiplier: zParameterCFGRescaleMultiplier,
guidance: zParameterGuidance,
img2imgStrength: zParameterStrength,
optimizedDenoisingEnabled: z.boolean(),
iterations: z.number(),
scheduler: zParameterScheduler,
upscaleScheduler: zParameterScheduler,
upscaleCfgScale: zParameterCFGScale,
seed: zParameterSeed,
shouldRandomizeSeed: z.boolean(),
steps: zParameterSteps,
model: zParameterModel.nullable(),
vae: zParameterVAEModel.nullable(),
vaePrecision: zParameterPrecision,
fluxVAE: zParameterVAEModel.nullable(),
seamlessXAxis: z.boolean(),
seamlessYAxis: z.boolean(),
clipSkip: z.number(),
shouldUseCpuNoise: z.boolean(),
positivePrompt: zParameterPositivePrompt,
negativePrompt: zParameterNegativePrompt,
positivePrompt2: zParameterPositiveStylePromptSDXL,
negativePrompt2: zParameterNegativeStylePromptSDXL,
shouldConcatPrompts: z.boolean(),
refinerModel: zParameterSDXLRefinerModel.nullable(),
refinerSteps: z.number(),
refinerCFGScale: z.number(),
refinerScheduler: zParameterScheduler,
refinerPositiveAestheticScore: z.number(),
refinerNegativeAestheticScore: z.number(),
refinerStart: z.number(),
t5EncoderModel: zParameterT5EncoderModel.nullable(),
clipEmbedModel: zParameterCLIPEmbedModel.nullable(),
clipLEmbedModel: zParameterCLIPLEmbedModel.nullable(),
clipGEmbedModel: zParameterCLIPGEmbedModel.nullable(),
controlLora: zParameterControlLoRAModel.nullable(),
dimensions: zDimensionsState,
});
export type ParamsState = z.infer<typeof zParamsState>;
const INITIAL_PARAMS_STATE = zParamsState.parse({});
export const getInitialParamsState = () => deepClone(INITIAL_PARAMS_STATE);
export const getInitialParamsState = (): ParamsState => ({
maskBlur: 16,
maskBlurMethod: 'box',
canvasCoherenceMode: 'Gaussian Blur',
canvasCoherenceMinDenoise: 0,
canvasCoherenceEdgeSize: 16,
infillMethod: 'lama',
infillTileSize: 32,
infillPatchmatchDownscaleSize: 1,
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
cfgScale: 7.5,
cfgRescaleMultiplier: 0,
guidance: 4,
img2imgStrength: 0.75,
optimizedDenoisingEnabled: true,
iterations: 1,
scheduler: 'dpmpp_3m_k',
upscaleScheduler: 'kdpm_2',
upscaleCfgScale: 2,
seed: 0,
shouldRandomizeSeed: true,
steps: 30,
model: null,
vae: null,
vaePrecision: 'fp32',
fluxVAE: null,
seamlessXAxis: false,
seamlessYAxis: false,
clipSkip: 0,
shouldUseCpuNoise: true,
positivePrompt: '',
negativePrompt: null,
positivePrompt2: '',
negativePrompt2: '',
shouldConcatPrompts: true,
refinerModel: null,
refinerSteps: 20,
refinerCFGScale: 7.5,
refinerScheduler: 'euler',
refinerPositiveAestheticScore: 6,
refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8,
t5EncoderModel: null,
clipEmbedModel: null,
clipLEmbedModel: null,
clipGEmbedModel: null,
controlLora: null,
dimensions: {
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
},
});
const zInpaintMasks = z.object({
isHidden: z.boolean(),
@@ -595,38 +624,45 @@ const zRegionalGuidance = z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRegionalGuidanceState),
});
const zCanvasState = z.object({
_version: z.literal(3).default(3),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
inpaintMasks: zInpaintMasks.default({ isHidden: false, entities: [] }),
rasterLayers: zRasterLayers.default({ isHidden: false, entities: [] }),
controlLayers: zControlLayers.default({ isHidden: false, entities: [] }),
regionalGuidance: zRegionalGuidance.default({ isHidden: false, entities: [] }),
bbox: zBboxState.default({
export const zCanvasState = z.object({
_version: z.literal(3),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
inpaintMasks: zInpaintMasks,
rasterLayers: zRasterLayers,
controlLayers: zControlLayers,
regionalGuidance: zRegionalGuidance,
bbox: zBboxState,
});
export type CanvasState = z.infer<typeof zCanvasState>;
export const getInitialCanvasState = (): CanvasState => ({
_version: 3,
selectedEntityIdentifier: null,
bookmarkedEntityIdentifier: null,
inpaintMasks: { isHidden: false, entities: [] },
rasterLayers: { isHidden: false, entities: [] },
controlLayers: { isHidden: false, entities: [] },
regionalGuidance: { isHidden: false, entities: [] },
bbox: {
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
scaleMethod: 'auto',
scaledSize: { width: 512, height: 512 },
modelBase: 'sd-1',
}),
},
});
export type CanvasState = z.infer<typeof zCanvasState>;
const zRefImagesState = z.object({
selectedEntityId: z.string().nullable().default(null),
isPanelOpen: z.boolean().default(false),
entities: z.array(zRefImageState).default(() => []),
export const zRefImagesState = z.object({
selectedEntityId: z.string().nullable(),
isPanelOpen: z.boolean(),
entities: z.array(zRefImageState),
});
export type RefImagesState = z.infer<typeof zRefImagesState>;
const INITIAL_REF_IMAGES_STATE = zRefImagesState.parse({});
export const getInitialRefImagesState = () => deepClone(INITIAL_REF_IMAGES_STATE);
/**
* Gets a fresh canvas initial state with no references in memory to existing objects.
*/
const CANVAS_INITIAL_STATE = zCanvasState.parse({});
export const getInitialCanvasState = () => deepClone(CANVAS_INITIAL_STATE);
export const getInitialRefImagesState = (): RefImagesState => ({
selectedEntityId: null,
isPanelOpen: false,
entities: [],
});
export const zCanvasReferenceImageState_OLD = zCanvasEntityBase.extend({
type: z.literal('reference_image'),

View File

@@ -1,25 +1,29 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { buildZodTypeGuard } from 'common/util/zodUtils';
import { isPlainObject } from 'es-toolkit';
import { assert } from 'tsafe';
import { z } from 'zod';
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
export const isSeedBehaviour = buildZodTypeGuard(zSeedBehaviour);
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
export interface DynamicPromptsState {
_version: 1;
maxPrompts: number;
combinatorial: boolean;
prompts: string[];
parsingError: string | undefined | null;
isError: boolean;
isLoading: boolean;
seedBehaviour: SeedBehaviour;
}
const zDynamicPromptsState = z.object({
_version: z.literal(1),
maxPrompts: z.number().int().min(1).max(1000),
combinatorial: z.boolean(),
prompts: z.array(z.string()),
parsingError: z.string().nullish(),
isError: z.boolean(),
isLoading: z.boolean(),
seedBehaviour: zSeedBehaviour,
});
export type DynamicPromptsState = z.infer<typeof zDynamicPromptsState>;
const initialDynamicPromptsState: DynamicPromptsState = {
const getInitialState = (): DynamicPromptsState => ({
_version: 1,
maxPrompts: 100,
combinatorial: true,
@@ -28,11 +32,11 @@ const initialDynamicPromptsState: DynamicPromptsState = {
isError: false,
isLoading: false,
seedBehaviour: 'PER_ITERATION',
};
});
export const dynamicPromptsSlice = createSlice({
const slice = createSlice({
name: 'dynamicPrompts',
initialState: initialDynamicPromptsState,
initialState: getInitialState(),
reducers: {
maxPromptsChanged: (state, action: PayloadAction<number>) => {
state.maxPrompts = action.payload;
@@ -63,21 +67,22 @@ export const {
isErrorChanged,
isLoadingChanged,
seedBehaviourChanged,
} = dynamicPromptsSlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateDynamicPromptsState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const dynamicPromptsPersistConfig: PersistConfig<DynamicPromptsState> = {
name: dynamicPromptsSlice.name,
initialState: initialDynamicPromptsState,
migrate: migrateDynamicPromptsState,
persistDenylist: ['prompts'],
export const dynamicPromptsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zDynamicPromptsState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zDynamicPromptsState.parse(state);
},
persistDenylist: ['prompts', 'parsingError', 'isError', 'isLoading'],
},
};
export const selectDynamicPromptsSlice = (state: RootState) => state.dynamicPrompts;

View File

@@ -21,7 +21,14 @@ export const ImageMenuItemNewCanvasFromImageSubMenu = memo(() => {
const onClickNewCanvasWithRasterLayerFromImage = useCallback(async () => {
const { dispatch, getState } = store;
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
await newCanvasFromImage({ imageDTO, withResize: false, type: 'raster_layer', dispatch, getState });
await newCanvasFromImage({
imageDTO,
withResize: false,
withInpaintMask: true,
type: 'raster_layer',
dispatch,
getState,
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
@@ -32,7 +39,14 @@ export const ImageMenuItemNewCanvasFromImageSubMenu = memo(() => {
const onClickNewCanvasWithControlLayerFromImage = useCallback(async () => {
const { dispatch, getState } = store;
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
await newCanvasFromImage({ imageDTO, withResize: false, type: 'control_layer', dispatch, getState });
await newCanvasFromImage({
imageDTO,
withResize: false,
withInpaintMask: true,
type: 'control_layer',
dispatch,
getState,
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
@@ -43,7 +57,14 @@ export const ImageMenuItemNewCanvasFromImageSubMenu = memo(() => {
const onClickNewCanvasWithRasterLayerFromImageWithResize = useCallback(async () => {
const { dispatch, getState } = store;
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
await newCanvasFromImage({ imageDTO, withResize: true, type: 'raster_layer', dispatch, getState });
await newCanvasFromImage({
imageDTO,
withResize: true,
withInpaintMask: true,
type: 'raster_layer',
dispatch,
getState,
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
@@ -54,7 +75,14 @@ export const ImageMenuItemNewCanvasFromImageSubMenu = memo(() => {
const onClickNewCanvasWithControlLayerFromImageWithResize = useCallback(async () => {
const { dispatch, getState } = store;
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
await newCanvasFromImage({ imageDTO, withResize: true, type: 'control_layer', dispatch, getState });
await newCanvasFromImage({
imageDTO,
withResize: true,
withInpaintMask: true,
type: 'control_layer',
dispatch,
getState,
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),

View File

@@ -1,5 +1,6 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
@@ -14,7 +15,7 @@ export const ImageMenuItemSendToUpscale = memo(() => {
const imageDTO = useImageDTOContext();
const handleSendToCanvas = useCallback(() => {
dispatch(upscaleInitialImageChanged(imageDTO));
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
navigationApi.switchToTab('upscaling');
toast({
id: 'SENT_TO_CANVAS',

View File

@@ -1,13 +1,23 @@
import { objectEquals } from '@observ33r/object-equals';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { uniq } from 'es-toolkit/compat';
import type { BoardRecordOrderBy } from 'services/api/types';
import { assert } from 'tsafe';
import type { BoardId, ComparisonMode, GalleryState, GalleryView, OrderDir } from './types';
import {
type BoardId,
type ComparisonMode,
type GalleryState,
type GalleryView,
type OrderDir,
zGalleryState,
} from './types';
const initialGalleryState: GalleryState = {
const getInitialState = (): GalleryState => ({
selection: [],
shouldAutoSwitch: true,
autoAssignBoardOnClick: true,
@@ -26,11 +36,11 @@ const initialGalleryState: GalleryState = {
shouldShowArchivedBoards: false,
boardsListOrderBy: 'created_at',
boardsListOrderDir: 'DESC',
};
});
export const gallerySlice = createSlice({
const slice = createSlice({
name: 'gallery',
initialState: initialGalleryState,
initialState: getInitialState(),
reducers: {
imageSelected: (state, action: PayloadAction<string | null>) => {
// Let's be efficient here and not update the selection unless it has actually changed. This helps to prevent
@@ -187,21 +197,22 @@ export const {
searchTermChanged,
boardsListOrderByChanged,
boardsListOrderDirChanged,
} = gallerySlice.actions;
} = slice.actions;
export const selectGallerySlice = (state: RootState) => state.gallery;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateGalleryState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const galleryPersistConfig: PersistConfig<GalleryState> = {
name: gallerySlice.name,
initialState: initialGalleryState,
migrate: migrateGalleryState,
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'],
export const gallerySliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zGalleryState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zGalleryState.parse(state);
},
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'],
},
};

View File

@@ -0,0 +1,13 @@
import type { S } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import type { BoardRecordOrderBy } from './types';
describe('Gallery Types', () => {
// Ensure zod types match OpenAPI types
test('BoardRecordOrderBy', () => {
assert<Equals<BoardRecordOrderBy, S['BoardRecordOrderBy']>>();
});
});

View File

@@ -1,31 +1,41 @@
import type { BoardRecordOrderBy, ImageCategory } from 'services/api/types';
import type { ImageCategory } from 'services/api/types';
import z from 'zod';
const zGalleryView = z.enum(['images', 'assets']);
export type GalleryView = z.infer<typeof zGalleryView>;
const zBoardId = z.union([z.literal('none'), z.intersection(z.string(), z.record(z.never(), z.never()))]);
export type BoardId = z.infer<typeof zBoardId>;
const zComparisonMode = z.enum(['slider', 'side-by-side', 'hover']);
export type ComparisonMode = z.infer<typeof zComparisonMode>;
const zComparisonFit = z.enum(['contain', 'fill']);
export type ComparisonFit = z.infer<typeof zComparisonFit>;
const zOrderDir = z.enum(['ASC', 'DESC']);
export type OrderDir = z.infer<typeof zOrderDir>;
const zBoardRecordOrderBy = z.enum(['created_at', 'board_name']);
export type BoardRecordOrderBy = z.infer<typeof zBoardRecordOrderBy>;
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
export type GalleryView = 'images' | 'assets';
export type BoardId = 'none' | (string & Record<never, never>);
export type ComparisonMode = 'slider' | 'side-by-side' | 'hover';
export type ComparisonFit = 'contain' | 'fill';
export type OrderDir = 'ASC' | 'DESC';
export const zGalleryState = z.object({
selection: z.array(z.string()),
shouldAutoSwitch: z.boolean(),
autoAssignBoardOnClick: z.boolean(),
autoAddBoardId: zBoardId,
galleryImageMinimumWidth: z.number(),
selectedBoardId: zBoardId,
galleryView: zGalleryView,
boardSearchText: z.string(),
starredFirst: z.boolean(),
orderDir: zOrderDir,
searchTerm: z.string(),
alwaysShowImageSizeBadge: z.boolean(),
imageToCompare: z.string().nullable(),
comparisonMode: zComparisonMode,
comparisonFit: zComparisonFit,
shouldShowArchivedBoards: z.boolean(),
boardsListOrderBy: zBoardRecordOrderBy,
boardsListOrderDir: zOrderDir,
});
export type GalleryState = {
selection: string[];
shouldAutoSwitch: boolean;
autoAssignBoardOnClick: boolean;
autoAddBoardId: BoardId;
galleryImageMinimumWidth: number;
selectedBoardId: BoardId;
galleryView: GalleryView;
boardSearchText: string;
starredFirst: boolean;
orderDir: OrderDir;
searchTerm: string;
alwaysShowImageSizeBadge: boolean;
imageToCompare: string | null;
comparisonMode: ComparisonMode;
comparisonFit: ComparisonFit;
shouldShowArchivedBoards: boolean;
boardsListOrderBy: BoardRecordOrderBy;
boardsListOrderDir: OrderDir;
};
export type GalleryState = z.infer<typeof zGalleryState>;

View File

@@ -58,7 +58,7 @@ export const setRegionalGuidanceReferenceImage = (arg: {
export const setUpscaleInitialImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
const { imageDTO, dispatch } = arg;
dispatch(upscaleInitialImageChanged(imageDTO));
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
};
export const setNodeImageFieldImage = (arg: {

View File

@@ -89,6 +89,7 @@ import { t } from 'i18next';
import type { ComponentType } from 'react';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, ModelType } from 'services/api/types';
import { assert } from 'tsafe';
@@ -787,11 +788,55 @@ const LoRAs: CollectionMetadataHandler<LoRA[]> = {
const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
[SingleMetadataKey]: true,
type: 'CanvasLayers',
parse: async (metadata) => {
parse: async (metadata, store) => {
const raw = getProperty(metadata, 'canvas_v2_metadata');
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
// the zImageWithDims schema.
const parsed = await zCanvasMetadata.parseAsync(raw);
for (const entity of parsed.controlLayers) {
if (entity.controlAdapter.model) {
await throwIfModelDoesNotExist(entity.controlAdapter.model.key, store);
}
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
}
for (const entity of parsed.inpaintMasks) {
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
}
for (const entity of parsed.rasterLayers) {
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
}
for (const entity of parsed.regionalGuidance) {
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
for (const refImage of entity.referenceImages) {
if (refImage.config.image) {
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
}
if (refImage.config.model) {
await throwIfModelDoesNotExist(refImage.config.model.key, store);
}
}
}
return Promise.resolve(parsed);
},
recall: (value, store) => {
@@ -824,27 +869,39 @@ const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
const RefImages: CollectionMetadataHandler<RefImageState[]> = {
[CollectionMetadataKey]: true,
type: 'RefImages',
parse: async (metadata) => {
parse: async (metadata, store) => {
let parsed: RefImageState[] | null = null;
try {
// First attempt to parse from the v6 slot
const raw = getProperty(metadata, 'ref_images');
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
// the zImageWithDims schema.
const parsed = await z.array(zRefImageState).parseAsync(raw);
return Promise.resolve(parsed);
parsed = z.array(zRefImageState).parse(raw);
} catch {
// Fall back to extracting from canvas metadata]
const raw = getProperty(metadata, 'canvas_v2_metadata.referenceImages.entities');
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
// the zImageWithDims schema.
const oldParsed = await z.array(zCanvasReferenceImageState_OLD).parseAsync(raw);
const parsed: RefImageState[] = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({
parsed = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({
id,
config: ipAdapter,
isEnabled,
}));
return parsed;
}
if (!parsed) {
throw new Error('No valid reference images found in metadata');
}
for (const refImage of parsed) {
if (refImage.config.image) {
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
}
if (refImage.config.model) {
await throwIfModelDoesNotExist(refImage.config.model.key, store);
}
}
return parsed;
},
recall: (value, store) => {
const entities = value.map((data) => ({ ...data, id: getPrefixedId('reference_image') }));
@@ -1241,3 +1298,19 @@ const isCompatibleWithMainModel = (candidate: ModelIdentifierField, store: AppSt
}
return candidate.base === base;
};
const throwIfImageDoesNotExist = async (name: string, store: AppStore): Promise<void> => {
try {
await store.dispatch(imagesApi.endpoints.getImageDTO.initiate(name, { subscribe: false })).unwrap();
} catch {
throw new Error(`Image with name ${name} does not exist`);
}
};
const throwIfModelDoesNotExist = async (key: string, store: AppStore): Promise<void> => {
try {
await store.dispatch(modelsApi.endpoints.getModelConfig.initiate(key, { subscribe: false }));
} catch {
throw new Error(`Model with key ${key} does not exist`);
}
};

View File

@@ -1,7 +1,6 @@
import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
import type { AnyModelConfig } from 'services/api/types';
/**
* Raised when a model config is unable to be fetched.
@@ -47,45 +46,6 @@ const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
}
};
/**
* Fetches the model config for a given model name, base model, and model type. This provides backwards compatibility
* for MM1 model identifiers.
* @param name The model name.
* @param base The base model.
* @param type The model type.
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type: ModelType): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(
modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }, { subscribe: false })
);
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
}
};
/**
* Fetches the model config given an identifier. First attempts to fetch by key, then falls back to fetching by attrs.
* @param identifier The model identifier.
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise<AnyModelConfig> => {
try {
return await fetchModelConfig(identifier.key);
} catch {
try {
return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type);
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`);
}
}
};
/**
* Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.
* @param key The model key.

View File

@@ -1,21 +1,28 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { ModelType } from 'services/api/types';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { zModelType } from 'features/nodes/types/common';
import { assert } from 'tsafe';
import z from 'zod';
export type FilterableModelType = Exclude<ModelType, 'onnx'> | 'refiner';
const zFilterableModelType = zModelType.exclude(['onnx']).or(z.literal('refiner'));
export type FilterableModelType = z.infer<typeof zFilterableModelType>;
type ModelManagerState = {
_version: 1;
selectedModelKey: string | null;
selectedModelMode: 'edit' | 'view';
searchTerm: string;
filteredModelType: FilterableModelType | null;
scanPath: string | undefined;
shouldInstallInPlace: boolean;
};
const zModelManagerState = z.object({
_version: z.literal(1),
selectedModelKey: z.string().nullable(),
selectedModelMode: z.enum(['edit', 'view']),
searchTerm: z.string(),
filteredModelType: zFilterableModelType.nullable(),
scanPath: z.string().optional(),
shouldInstallInPlace: z.boolean(),
});
const initialModelManagerState: ModelManagerState = {
type ModelManagerState = z.infer<typeof zModelManagerState>;
const getInitialState = (): ModelManagerState => ({
_version: 1,
selectedModelKey: null,
selectedModelMode: 'view',
@@ -23,11 +30,11 @@ const initialModelManagerState: ModelManagerState = {
searchTerm: '',
scanPath: undefined,
shouldInstallInPlace: true,
};
});
export const modelManagerV2Slice = createSlice({
const slice = createSlice({
name: 'modelmanagerV2',
initialState: initialModelManagerState,
initialState: getInitialState(),
reducers: {
setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
state.selectedModelMode = 'view';
@@ -58,21 +65,22 @@ export const {
setSelectedModelMode,
setScanPath,
shouldInstallInPlaceChanged,
} = modelManagerV2Slice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateModelManagerState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const modelManagerV2PersistConfig: PersistConfig<ModelManagerState> = {
name: modelManagerV2Slice.name,
initialState: initialModelManagerState,
migrate: migrateModelManagerState,
persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'],
export const modelManagerSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zModelManagerState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zModelManagerState.parse(state);
},
persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'],
},
};
export const selectModelManagerV2Slice = (state: RootState) => state.modelmanagerV2;

View File

@@ -14,7 +14,13 @@ import type {
ReactFlowProps,
ReactFlowState,
} from '@xyflow/react';
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from '@xyflow/react';
import {
Background,
ReactFlow,
SelectionMode,
useStore as useReactFlowStore,
useUpdateNodeInternals,
} from '@xyflow/react';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { $isSelectingOutputNode, $outputNodeId } from 'features/nodes/components/sidePanel/workflow/publish';
@@ -256,7 +262,7 @@ export const Flow = memo(() => {
style={flowStyles}
onPaneClick={handlePaneClick}
deleteKeyCode={null}
selectionMode={selectionMode}
selectionMode={selectionMode === 'full' ? SelectionMode.Full : SelectionMode.Partial}
elevateEdgesOnSelect
nodeDragThreshold={1}
noDragClassName={NO_DRAG_CLASS}

View File

@@ -11,14 +11,15 @@ import type {
XYPosition,
} from '@xyflow/react';
import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from '@xyflow/react';
import type { PersistConfig } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { deepClone } from 'common/util/deepClone';
import { isPlainObject } from 'es-toolkit';
import {
addElement,
removeElement,
reparentElement,
} from 'features/nodes/components/sidePanel/builder/form-manipulation';
import type { NodesState } from 'features/nodes/store/types';
import { type NodesState, zNodesState } from 'features/nodes/store/types';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
BoardFieldValue,
@@ -127,6 +128,7 @@ import {
import { atom, computed } from 'nanostores';
import type { MouseEvent } from 'react';
import type { UndoableOptions } from 'redux-undo';
import { assert } from 'tsafe';
import type { z } from 'zod';
import type { PendingConnection, Templates } from './types';
@@ -151,11 +153,11 @@ export const getInitialWorkflow = (): Omit<NodesState, 'mode' | 'formFieldInitia
};
};
const initialState: NodesState = {
const getInitialState = (): NodesState => ({
_version: 1,
formFieldInitialValues: {},
...getInitialWorkflow(),
};
});
type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string;
@@ -208,9 +210,9 @@ const fieldValueReducer = <T extends FieldValue>(
field.value = result.data;
};
export const nodesSlice = createSlice({
const slice = createSlice({
name: 'nodes',
initialState: initialState,
initialState: getInitialState(),
reducers: {
nodesChanged: (state, action: PayloadAction<NodeChange<AnyNode>[]>) => {
// In v12.7.0, @xyflow/react added a `domAttributes` property to the node data. One DOM attribute is
@@ -588,7 +590,7 @@ export const nodesSlice = createSlice({
}
node.data.notes = value;
},
nodeEditorReset: () => deepClone(initialState),
nodeEditorReset: () => getInitialState(),
workflowNameChanged: (state, action: PayloadAction<string>) => {
state.name = action.payload;
},
@@ -673,7 +675,7 @@ export const nodesSlice = createSlice({
const formFieldInitialValues = getFormFieldInitialValues(workflowExtra.form, nodes);
return {
...deepClone(initialState),
...getInitialState(),
...deepClone(workflowExtra),
formFieldInitialValues,
nodes: nodes.map((node) => ({ ...SHARED_NODE_PROPERTIES, ...node })),
@@ -758,7 +760,7 @@ export const {
workflowLoaded,
undo,
redo,
} = nodesSlice.actions;
} = slice.actions;
export const $cursorPos = atom<XYPosition | null>(null);
export const $templates = atom<Templates>({});
@@ -775,21 +777,6 @@ export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
export const $addNodeCmdk = atom(false);
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateNodesState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const nodesPersistConfig: PersistConfig<NodesState> = {
name: nodesSlice.name,
initialState: initialState,
migrate: migrateNodesState,
persistDenylist: [],
};
type NodeSelectionAction = {
type: ReturnType<typeof nodesChanged>['type'];
payload: NodeSelectionChange[];
@@ -893,10 +880,10 @@ const isHighFrequencyWorkflowDetailsAction = isAnyOf(
// a note in a notes node, we don't want to create a new undo group for every keystroke.
const isHighFrequencyNodeScopedAction = isAnyOf(nodeLabelChanged, nodeNotesChanged, notesNodeValueChanged);
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
const reduxUndoOptions: UndoableOptions<NodesState, UnknownAction> = {
limit: 64,
undoType: nodesSlice.actions.undo.type,
redoType: nodesSlice.actions.redo.type,
undoType: slice.actions.undo.type,
redoType: slice.actions.redo.type,
groupBy: (action, _state, _history) => {
if (isHighFrequencyFieldChangeAction(action)) {
// Group by type, node id and field name
@@ -928,7 +915,7 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
},
filter: (action, _state, _history) => {
// Ignore all actions from other slices
if (!action.type.startsWith(nodesSlice.name)) {
if (!action.type.startsWith(slice.name)) {
return false;
}
// Ignore actions that only select or deselect nodes and edges
@@ -943,6 +930,24 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
},
};
export const nodesSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zNodesState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zNodesState.parse(state);
},
},
undoableConfig: {
reduxUndoOptions,
},
};
// The form builder's initial values are based on the current values of the node fields in the workflow.
export const getFormFieldInitialValues = (form: BuilderForm, nodes: NodesState['nodes']) => {
const formFieldInitialValues: Record<string, StatefulFieldValue> = {};

View File

@@ -1,7 +1,8 @@
import type { HandleType } from '@xyflow/react';
import type { FieldInputTemplate, FieldOutputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
import type { AnyEdge, AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { type FieldInputTemplate, type FieldOutputTemplate, zStatefulFieldValue } from 'features/nodes/types/field';
import { type InvocationTemplate, type NodeExecutionState, zAnyEdge, zAnyNode } from 'features/nodes/types/invocation';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import z from 'zod';
export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
@@ -13,11 +14,13 @@ export type PendingConnection = {
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
};
export type WorkflowMode = 'edit' | 'view';
export type NodesState = {
_version: 1;
nodes: AnyNode[];
edges: AnyEdge[];
formFieldInitialValues: Record<string, StatefulFieldValue>;
} & Omit<WorkflowV3, 'nodes' | 'edges' | 'is_published'>;
export const zWorkflowMode = z.enum(['edit', 'view']);
export type WorkflowMode = z.infer<typeof zWorkflowMode>;
export const zNodesState = z.object({
_version: z.literal(1),
nodes: z.array(zAnyNode),
edges: z.array(zAnyEdge),
formFieldInitialValues: z.record(z.string(), zStatefulFieldValue),
...zWorkflowV3.omit({ nodes: true, edges: true, is_published: true }).shape,
});
export type NodesState = z.infer<typeof zNodesState>;

View File

@@ -1,34 +1,43 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { WorkflowMode } from 'features/nodes/store/types';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { type WorkflowMode, zWorkflowMode } from 'features/nodes/store/types';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import { atom, computed } from 'nanostores';
import type { SQLiteDirection, WorkflowRecordOrderBy } from 'services/api/types';
import {
type SQLiteDirection,
type WorkflowRecordOrderBy,
zSQLiteDirection,
zWorkflowRecordOrderBy,
} from 'services/api/types';
import z from 'zod';
export type WorkflowLibraryView = 'recent' | 'yours' | 'private' | 'shared' | 'defaults' | 'published';
const zWorkflowLibraryView = z.enum(['recent', 'yours', 'private', 'shared', 'defaults', 'published']);
export type WorkflowLibraryView = z.infer<typeof zWorkflowLibraryView>;
type WorkflowLibraryState = {
mode: WorkflowMode;
view: WorkflowLibraryView;
orderBy: WorkflowRecordOrderBy;
direction: SQLiteDirection;
searchTerm: string;
selectedTags: string[];
};
const zWorkflowLibraryState = z.object({
mode: zWorkflowMode,
view: zWorkflowLibraryView,
orderBy: zWorkflowRecordOrderBy,
direction: zSQLiteDirection,
searchTerm: z.string(),
selectedTags: z.array(z.string()),
});
type WorkflowLibraryState = z.infer<typeof zWorkflowLibraryState>;
const initialWorkflowLibraryState: WorkflowLibraryState = {
const getInitialState = (): WorkflowLibraryState => ({
mode: 'view',
searchTerm: '',
orderBy: 'opened_at',
direction: 'DESC',
selectedTags: [],
view: 'defaults',
};
});
export const workflowLibrarySlice = createSlice({
const slice = createSlice({
name: 'workflowLibrary',
initialState: initialWorkflowLibraryState,
initialState: getInitialState(),
reducers: {
workflowModeChanged: (state, action: PayloadAction<WorkflowMode>) => {
state.mode = action.payload;
@@ -73,16 +82,15 @@ export const {
workflowLibraryTagToggled,
workflowLibraryTagsReset,
workflowLibraryViewChanged,
} = workflowLibrarySlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateWorkflowLibraryState = (state: any): any => state;
export const workflowLibraryPersistConfig: PersistConfig<WorkflowLibraryState> = {
name: workflowLibrarySlice.name,
initialState: initialWorkflowLibraryState,
migrate: migrateWorkflowLibraryState,
persistDenylist: [],
export const workflowLibrarySliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zWorkflowLibraryState,
getInitialState,
persistConfig: {
migrate: (state) => zWorkflowLibraryState.parse(state),
},
};
const selectWorkflowLibrarySlice = (state: RootState) => state.workflowLibrary;

View File

@@ -1,8 +1,10 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import { SelectionMode } from '@xyflow/react';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { Selector } from 'react-redux';
import { assert } from 'tsafe';
import z from 'zod';
export const zLayeringStrategy = z.enum(['network-simplex', 'longest-path']);
@@ -11,25 +13,28 @@ export const zLayoutDirection = z.enum(['TB', 'LR']);
type LayoutDirection = z.infer<typeof zLayoutDirection>;
export const zNodeAlignment = z.enum(['UL', 'UR', 'DL', 'DR']);
type NodeAlignment = z.infer<typeof zNodeAlignment>;
const zSelectionMode = z.enum(['partial', 'full']);
export type WorkflowSettingsState = {
_version: 1;
shouldShowMinimapPanel: boolean;
layeringStrategy: LayeringStrategy;
nodeSpacing: number;
layerSpacing: number;
layoutDirection: LayoutDirection;
shouldValidateGraph: boolean;
shouldAnimateEdges: boolean;
nodeAlignment: NodeAlignment;
nodeOpacity: number;
shouldSnapToGrid: boolean;
shouldColorEdges: boolean;
shouldShowEdgeLabels: boolean;
selectionMode: SelectionMode;
};
const zWorkflowSettingsState = z.object({
_version: z.literal(1),
shouldShowMinimapPanel: z.boolean(),
layeringStrategy: zLayeringStrategy,
nodeSpacing: z.number(),
layerSpacing: z.number(),
layoutDirection: zLayoutDirection,
shouldValidateGraph: z.boolean(),
shouldAnimateEdges: z.boolean(),
nodeAlignment: zNodeAlignment,
nodeOpacity: z.number(),
shouldSnapToGrid: z.boolean(),
shouldColorEdges: z.boolean(),
shouldShowEdgeLabels: z.boolean(),
selectionMode: zSelectionMode,
});
const initialState: WorkflowSettingsState = {
export type WorkflowSettingsState = z.infer<typeof zWorkflowSettingsState>;
const getInitialState = (): WorkflowSettingsState => ({
_version: 1,
shouldShowMinimapPanel: true,
layeringStrategy: 'network-simplex',
@@ -43,12 +48,12 @@ const initialState: WorkflowSettingsState = {
shouldColorEdges: true,
shouldShowEdgeLabels: false,
nodeOpacity: 1,
selectionMode: SelectionMode.Partial,
};
selectionMode: 'partial',
});
export const workflowSettingsSlice = createSlice({
const slice = createSlice({
name: 'workflowSettings',
initialState,
initialState: getInitialState(),
reducers: {
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowMinimapPanel = action.payload;
@@ -87,7 +92,7 @@ export const workflowSettingsSlice = createSlice({
state.nodeAlignment = action.payload;
},
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
state.selectionMode = action.payload ? 'full' : 'partial';
},
},
});
@@ -106,21 +111,21 @@ export const {
shouldValidateGraphChanged,
nodeOpacityChanged,
selectionModeChanged,
} = workflowSettingsSlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateWorkflowSettingsState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const workflowSettingsPersistConfig: PersistConfig<WorkflowSettingsState> = {
name: workflowSettingsSlice.name,
initialState,
migrate: migrateWorkflowSettingsState,
persistDenylist: [],
export const workflowSettingsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zWorkflowSettingsState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zWorkflowSettingsState.parse(state);
},
},
};
export const selectWorkflowSettingsSlice = (state: RootState) => state.workflowSettings;

View File

@@ -92,7 +92,7 @@ export const zMainModelBase = z.enum([
]);
type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([
export const zModelType = z.enum([
'main',
'vae',
'lora',

View File

@@ -43,7 +43,7 @@ export const zNotesNodeData = z.object({
isOpen: z.boolean(),
notes: z.string(),
});
const _zCurrentImageNodeData = z.object({
const zCurrentImageNodeData = z.object({
id: z.string().trim().min(1),
type: z.literal('current_image'),
label: z.string(),
@@ -52,12 +52,35 @@ const _zCurrentImageNodeData = z.object({
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
type CurrentImageNodeData = z.infer<typeof _zCurrentImageNodeData>;
type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
export type InvocationNode = Node<InvocationNodeData, 'invocation'>;
export type NotesNode = Node<NotesNodeData, 'notes'>;
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
export type AnyNode = InvocationNode | NotesNode | CurrentImageNode;
const zInvocationNodeValidationSchema = z.looseObject({
type: z.literal('invocation'),
data: zInvocationNodeData,
});
const zInvocationNode = z.custom<Node<InvocationNodeData, 'invocation'>>(
(val) => zInvocationNodeValidationSchema.safeParse(val).success
);
export type InvocationNode = z.infer<typeof zInvocationNode>;
const zNotesNodeValidationSchema = z.looseObject({
type: z.literal('notes'),
data: zNotesNodeData,
});
const zNotesNode = z.custom<Node<NotesNodeData, 'notes'>>((val) => zNotesNodeValidationSchema.safeParse(val).success);
export type NotesNode = z.infer<typeof zNotesNode>;
const zCurrentImageNodeValidationSchema = z.looseObject({
type: z.literal('current_image'),
data: zCurrentImageNodeData,
});
const zCurrentImageNode = z.custom<Node<CurrentImageNodeData, 'current_image'>>(
(val) => zCurrentImageNodeValidationSchema.safeParse(val).success
);
export type CurrentImageNode = z.infer<typeof zCurrentImageNode>;
export const zAnyNode = z.union([zInvocationNode, zNotesNode, zCurrentImageNode]);
export type AnyNode = z.infer<typeof zAnyNode>;
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
Boolean(node && node.type === 'invocation');
@@ -83,13 +106,29 @@ export type NodeExecutionState = z.infer<typeof _zNodeExecutionState>;
// #endregion
// #region Edges
const _zInvocationNodeEdgeCollapsedData = z.object({
const zDefaultInvocationNodeEdgeValidationSchema = z.looseObject({
type: z.literal('default'),
});
const zDefaultInvocationNodeEdge = z.custom<Edge<Record<string, never>, 'default'>>(
(val) => zDefaultInvocationNodeEdgeValidationSchema.safeParse(val).success
);
export type DefaultInvocationNodeEdge = z.infer<typeof zDefaultInvocationNodeEdge>;
const zInvocationNodeEdgeCollapsedData = z.object({
count: z.number().int().min(1),
});
type InvocationNodeEdgeCollapsedData = z.infer<typeof _zInvocationNodeEdgeCollapsedData>;
export type DefaultInvocationNodeEdge = Edge<Record<string, never>, 'default'>;
export type CollapsedInvocationNodeEdge = Edge<InvocationNodeEdgeCollapsedData, 'collapsed'>;
export type AnyEdge = DefaultInvocationNodeEdge | CollapsedInvocationNodeEdge;
const zInvocationNodeEdgeCollapsedValidationSchema = z.looseObject({
type: z.literal('default'),
data: zInvocationNodeEdgeCollapsedData,
});
type InvocationNodeEdgeCollapsedData = z.infer<typeof zInvocationNodeEdgeCollapsedData>;
const zCollapsedInvocationNodeEdge = z.custom<Edge<InvocationNodeEdgeCollapsedData, 'collapsed'>>(
(val) => zInvocationNodeEdgeCollapsedValidationSchema.safeParse(val).success
);
export type CollapsedInvocationNodeEdge = z.infer<typeof zCollapsedInvocationNodeEdge>;
export const zAnyEdge = z.union([zDefaultInvocationNodeEdge, zCollapsedInvocationNodeEdge]);
export type AnyEdge = z.infer<typeof zAnyEdge>;
// #endregion
export const isBatchNodeType = (type: string) =>

View File

@@ -4,6 +4,7 @@ import { range } from 'es-toolkit/compat';
import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import type { components } from 'services/api/schema';
import type { Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
@@ -18,7 +19,7 @@ const getExtendedPrompts = (arg: {
// Normally, the seed behaviour implicity determines the batch size. But when we use models without seeds (like
// ChatGPT 4o) in conjunction with the per-prompt seed behaviour, we lose out on that implicit batch size. To rectify
// this, we need to create a batch of the right size by repeating the prompts.
if (seedBehaviour === 'PER_PROMPT' || model.base === 'chatgpt-4o' || model.base === 'flux-kontext') {
if (seedBehaviour === 'PER_PROMPT' || API_BASE_MODELS.includes(model.base)) {
return range(iterations).flatMap(() => prompts);
}
return prompts;

View File

@@ -1,4 +1,5 @@
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
@@ -6,13 +7,35 @@ import { ModelPicker } from 'features/parameters/components/ModelPicker';
import { selectTileControlNetModel, tileControlnetModelChanged } from 'features/parameters/store/upscaleSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig } from 'services/api/types';
import { type ControlNetModelConfig, isControlNetModelConfig } from 'services/api/types';
const selectTileControlNetModelConfig = createSelector(
selectModelConfigsQuery,
selectTileControlNetModel,
(modelConfigs, modelIdentifierField) => {
if (!modelConfigs.data) {
return null;
}
if (!modelIdentifierField) {
return null;
}
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, modelIdentifierField.key);
if (!modelConfig) {
return null;
}
if (!isControlNetModelConfig(modelConfig)) {
return null;
}
return modelConfig;
}
);
const ParamTileControlNetModel = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const tileControlNetModel = useAppSelector(selectTileControlNetModel);
const tileControlNetModel = useAppSelector(selectTileControlNetModelConfig);
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useControlNetModels();

View File

@@ -1,21 +1,21 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { useMemo } from 'react';
import type { ImageDTO } from 'services/api/types';
const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) =>
const createIsTooLargeToUpscaleSelector = (imageWithDims?: ImageWithDims | null) =>
createSelector(selectUpscaleSlice, selectConfigSlice, (upscale, config) => {
const { upscaleModel, scale } = upscale;
const { maxUpscaleDimension } = config;
if (!maxUpscaleDimension || !upscaleModel || !imageDTO) {
if (!maxUpscaleDimension || !upscaleModel || !imageWithDims) {
// When these are missing, another warning will be shown
return false;
}
const { width, height } = imageDTO;
const { width, height } = imageWithDims;
const maxPixels = maxUpscaleDimension ** 2;
const upscaledPixels = width * scale * height * scale;
@@ -23,7 +23,7 @@ const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) =>
return upscaledPixels > maxPixels;
});
export const useIsTooLargeToUpscale = (imageDTO?: ImageDTO | null) => {
const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageDTO), [imageDTO]);
export const useIsTooLargeToUpscale = (imageWithDims?: ImageWithDims | null) => {
const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageWithDims), [imageWithDims]);
return useAppSelector(selectIsTooLargeToUpscale);
};

View File

@@ -1,24 +1,33 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterSpandrelImageToImageModel } from 'features/parameters/types/parameterSchemas';
import type { ControlNetModelConfig, ImageDTO } from 'services/api/types';
import type { ControlNetModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import z from 'zod';
export interface UpscaleState {
_version: 1;
upscaleModel: ParameterSpandrelImageToImageModel | null;
upscaleInitialImage: ImageDTO | null;
structure: number;
creativity: number;
tileControlnetModel: ControlNetModelConfig | null;
scale: number;
postProcessingModel: ParameterSpandrelImageToImageModel | null;
tileSize: number;
tileOverlap: number;
}
const zUpscaleState = z.object({
_version: z.literal(2),
upscaleModel: zModelIdentifierField.nullable(),
upscaleInitialImage: zImageWithDims.nullable(),
structure: z.number(),
creativity: z.number(),
tileControlnetModel: zModelIdentifierField.nullable(),
scale: z.number(),
postProcessingModel: zModelIdentifierField.nullable(),
tileSize: z.number(),
tileOverlap: z.number(),
});
const initialUpscaleState: UpscaleState = {
_version: 1,
export type UpscaleState = z.infer<typeof zUpscaleState>;
const getInitialState = (): UpscaleState => ({
_version: 2,
upscaleModel: null,
upscaleInitialImage: null,
structure: 0,
@@ -28,16 +37,19 @@ const initialUpscaleState: UpscaleState = {
postProcessingModel: null,
tileSize: 1024,
tileOverlap: 128,
};
});
export const upscaleSlice = createSlice({
const slice = createSlice({
name: 'upscale',
initialState: initialUpscaleState,
initialState: getInitialState(),
reducers: {
upscaleModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
state.upscaleModel = action.payload;
const result = zUpscaleState.shape.upscaleModel.safeParse(action.payload);
if (result.success) {
state.upscaleModel = result.data;
}
},
upscaleInitialImageChanged: (state, action: PayloadAction<ImageDTO | null>) => {
upscaleInitialImageChanged: (state, action: PayloadAction<ImageWithDims | null>) => {
state.upscaleInitialImage = action.payload;
},
structureChanged: (state, action: PayloadAction<number>) => {
@@ -47,13 +59,19 @@ export const upscaleSlice = createSlice({
state.creativity = action.payload;
},
tileControlnetModelChanged: (state, action: PayloadAction<ControlNetModelConfig | null>) => {
state.tileControlnetModel = action.payload;
const result = zUpscaleState.shape.tileControlnetModel.safeParse(action.payload);
if (result.success) {
state.tileControlnetModel = result.data;
}
},
scaleChanged: (state, action: PayloadAction<number>) => {
state.scale = action.payload;
},
postProcessingModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
state.postProcessingModel = action.payload;
const result = zUpscaleState.shape.postProcessingModel.safeParse(action.payload);
if (result.success) {
state.postProcessingModel = result.data;
}
},
tileSizeChanged: (state, action: PayloadAction<number>) => {
state.tileSize = action.payload;
@@ -74,21 +92,33 @@ export const {
postProcessingModelChanged,
tileSizeChanged,
tileOverlapChanged,
} = upscaleSlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateUpscaleState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const upscalePersistConfig: PersistConfig<UpscaleState> = {
name: upscaleSlice.name,
initialState: initialUpscaleState,
migrate: migrateUpscaleState,
persistDenylist: [],
export const upscaleSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zUpscaleState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state._version = 2;
// Migrate from v1 to v2: upscaleInitialImage was an ImageDTO, now it's an ImageWithDims
if (state.upscaleInitialImage) {
const { image_name, width, height } = state.upscaleInitialImage;
state.upscaleInitialImage = {
image_name,
width,
height,
};
}
}
return zUpscaleState.parse(state);
},
},
};
export const selectUpscaleSlice = (state: RootState) => state.upscale;

View File

@@ -13,14 +13,13 @@ export const CancelAllExceptCurrentButton = memo((props: ButtonProps) => {
<Button
isDisabled={api.isDisabled}
isLoading={api.isLoading}
aria-label={t('queue.clear')}
tooltip={t('queue.cancelAllExceptCurrentTooltip')}
leftIcon={<PiXCircle />}
colorScheme="error"
onClick={api.openDialog}
{...props}
>
{t('queue.clear')}
{t('queue.cancelAllExceptCurrentTooltip')}
</Button>
);
});

View File

@@ -0,0 +1,29 @@
import type { ButtonProps } from '@invoke-ai/ui-library';
import { Button } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashBold } from 'react-icons/pi';
import { useClearQueueDialog } from './ClearQueueConfirmationAlertDialog';
export const ClearQueueButton = memo((props: ButtonProps) => {
const { t } = useTranslation();
const api = useClearQueueDialog();
return (
<Button
isDisabled={api.isDisabled}
isLoading={api.isLoading}
aria-label={t('queue.clear')}
tooltip={t('queue.clearTooltip')}
leftIcon={<PiTrashBold />}
colorScheme="error"
onClick={api.openDialog}
{...props}
>
{t('queue.clear')}
</Button>
);
});
ClearQueueButton.displayName = 'ClearQueueButton';

View File

@@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next';
const [useClearQueueConfirmationAlertDialog] = buildUseBoolean(false);
const useClearQueueDialog = () => {
export const useClearQueueDialog = () => {
const dialog = useClearQueueConfirmationAlertDialog();
const clearQueue = useClearQueue();

View File

@@ -9,15 +9,19 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiListBold, PiPauseFill, PiPlayFill, PiQueueBold, PiXBold, PiXCircle } from 'react-icons/pi';
import { PiListBold, PiPauseFill, PiPlayFill, PiQueueBold, PiTrashBold, PiXBold, PiXCircle } from 'react-icons/pi';
import { useClearQueueDialog } from './ClearQueueConfirmationAlertDialog';
export const QueueActionsMenuButton = memo(() => {
const ref = useRef<HTMLDivElement>(null);
const { t } = useTranslation();
const isPauseEnabled = useFeatureStatus('pauseQueue');
const isResumeEnabled = useFeatureStatus('resumeQueue');
const isClearAllEnabled = useFeatureStatus('cancelAndClearAll');
const cancelAllExceptCurrent = useCancelAllExceptCurrentQueueItemDialog();
const cancelCurrentQueueItem = useCancelCurrentQueueItem();
const clearQueue = useClearQueueDialog();
const resumeProcessor = useResumeProcessor();
const pauseProcessor = usePauseProcessor();
const openQueue = useCallback(() => {
@@ -55,6 +59,17 @@ export const QueueActionsMenuButton = memo(() => {
>
{t('queue.cancelAllExceptCurrentTooltip')}
</MenuItem>
{isClearAllEnabled && (
<MenuItem
isDestructive
icon={<PiTrashBold />}
onClick={clearQueue.openDialog}
isLoading={clearQueue.isLoading}
isDisabled={clearQueue.isDisabled}
>
{t('queue.clearTooltip')}
</MenuItem>
)}
{isResumeEnabled && (
<MenuItem
icon={<PiPlayFill />}

View File

@@ -4,6 +4,7 @@ import { memo } from 'react';
import { CancelAllExceptCurrentButton } from './CancelAllExceptCurrentButton';
import ClearModelCacheButton from './ClearModelCacheButton';
import { ClearQueueButton } from './ClearQueueButton';
import PauseProcessorButton from './PauseProcessorButton';
import PruneQueueButton from './PruneQueueButton';
import ResumeProcessorButton from './ResumeProcessorButton';
@@ -11,19 +12,20 @@ import ResumeProcessorButton from './ResumeProcessorButton';
const QueueTabQueueControls = () => {
const isPauseEnabled = useFeatureStatus('pauseQueue');
const isResumeEnabled = useFeatureStatus('resumeQueue');
const isClearQueueEnabled = useFeatureStatus('cancelAndClearAll');
return (
<Flex flexDir="column" layerStyle="first" borderRadius="base" p={2} gap={2}>
<Flex gap={2}>
{(isPauseEnabled || isResumeEnabled) && (
<ButtonGroup w={28} orientation="vertical" size="sm">
<ButtonGroup orientation="vertical" size="sm">
{isResumeEnabled && <ResumeProcessorButton />}
{isPauseEnabled && <PauseProcessorButton />}
</ButtonGroup>
)}
<ButtonGroup w={28} orientation="vertical" size="sm">
<ButtonGroup orientation="vertical" size="sm">
<PruneQueueButton />
<CancelAllExceptCurrentButton />
{isClearQueueEnabled ? <ClearQueueButton /> : <CancelAllExceptCurrentButton />}
</ButtonGroup>
</Flex>
<ClearModelCacheButton />

View File

@@ -1,24 +1,27 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import z from 'zod';
interface QueueState {
listCursor: number | undefined;
listPriority: number | undefined;
selectedQueueItem: string | undefined;
resumeProcessorOnEnqueue: boolean;
}
const zQueueState = z.object({
listCursor: z.number().optional(),
listPriority: z.number().optional(),
selectedQueueItem: z.string().optional(),
resumeProcessorOnEnqueue: z.boolean(),
});
type QueueState = z.infer<typeof zQueueState>;
const initialQueueState: QueueState = {
const getInitialState = (): QueueState => ({
listCursor: undefined,
listPriority: undefined,
selectedQueueItem: undefined,
resumeProcessorOnEnqueue: true,
};
});
export const queueSlice = createSlice({
const slice = createSlice({
name: 'queue',
initialState: initialQueueState,
initialState: getInitialState(),
reducers: {
listCursorChanged: (state, action: PayloadAction<number | undefined>) => {
state.listCursor = action.payload;
@@ -33,7 +36,13 @@ export const queueSlice = createSlice({
},
});
export const { listCursorChanged, listPriorityChanged, listParamsReset } = queueSlice.actions;
export const { listCursorChanged, listPriorityChanged, listParamsReset } = slice.actions;
export const queueSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zQueueState,
getInitialState,
};
const selectQueueSlice = (state: RootState) => state.queue;
const createQueueSelector = <T>(selector: Selector<QueueState, T>) => createSelector(selectQueueSlice, selector);

View File

@@ -1,6 +1,7 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import type { SetUpscaleInitialImageDndTargetData } from 'features/dnd/dnd';
import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
@@ -10,11 +11,13 @@ import { selectUpscaleInitialImage, upscaleInitialImageChanged } from 'features/
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { useImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const UpscaleInitialImage = () => {
const dispatch = useAppDispatch();
const imageDTO = useAppSelector(selectUpscaleInitialImage);
const upscaleInitialImage = useAppSelector(selectUpscaleInitialImage);
const imageDTO = useImageDTO(upscaleInitialImage?.image_name);
const dndTargetData = useMemo<SetUpscaleInitialImageDndTargetData>(
() => setUpscaleInitialImageDndTarget.getData(),
[]
@@ -26,7 +29,7 @@ export const UpscaleInitialImage = () => {
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
dispatch(upscaleInitialImageChanged(imageDTO));
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
},
[dispatch]
);

View File

@@ -31,8 +31,10 @@ export const UpscaleWarning = () => {
const validModel = modelConfigs.find((cnetModel) => {
return cnetModel.base === model?.base && cnetModel.name.toLowerCase().includes('tile');
});
dispatch(tileControlnetModelChanged(validModel || null));
}, [model?.base, modelConfigs, dispatch]);
if (tileControlnetModel?.key !== validModel?.key) {
dispatch(tileControlnetModelChanged(validModel || null));
}
}, [dispatch, model?.base, modelConfigs, tileControlnetModel?.key]);
const isBaseModelCompatible = useMemo(() => {
return model && ['sd-1', 'sdxl'].includes(model.base);

View File

@@ -1,23 +1,33 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { atom } from 'nanostores';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { assert } from 'tsafe';
import z from 'zod';
import type { StylePresetState } from './types';
const zStylePresetState = z.object({
activeStylePresetId: z.string().nullable(),
searchTerm: z.string(),
viewMode: z.boolean(),
showPromptPreviews: z.boolean(),
});
const initialState: StylePresetState = {
type StylePresetState = z.infer<typeof zStylePresetState>;
const getInitialState = (): StylePresetState => ({
activeStylePresetId: null,
searchTerm: '',
viewMode: false,
showPromptPreviews: false,
};
});
export const stylePresetSlice = createSlice({
const slice = createSlice({
name: 'stylePreset',
initialState: initialState,
initialState: getInitialState(),
reducers: {
activeStylePresetIdChanged: (state, action: PayloadAction<string | null>) => {
state.activeStylePresetId = action.payload;
@@ -34,7 +44,7 @@ export const stylePresetSlice = createSlice({
},
extraReducers(builder) {
builder.addCase(paramsReset, () => {
return deepClone(initialState);
return getInitialState();
});
builder.addMatcher(stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled, (state, action) => {
if (state.activeStylePresetId === null) {
@@ -58,21 +68,21 @@ export const stylePresetSlice = createSlice({
});
export const { activeStylePresetIdChanged, searchTermChanged, viewModeChanged, showPromptPreviewsChanged } =
stylePresetSlice.actions;
slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateStylePresetState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const stylePresetPersistConfig: PersistConfig<StylePresetState> = {
name: stylePresetSlice.name,
initialState,
migrate: migrateStylePresetState,
persistDenylist: [],
export const stylePresetSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zStylePresetState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zStylePresetState.parse(state);
},
},
};
export const selectStylePresetSlice = (state: RootState) => state.stylePreset;

View File

@@ -1,6 +0,0 @@
export type StylePresetState = {
activeStylePresetId: string | null;
searchTerm: string;
viewMode: boolean;
showPromptPreviews: boolean;
};

View File

@@ -14,11 +14,11 @@ import {
Switch,
Text,
} from '@invoke-ai/ui-library';
import { useClearStorage } from 'app/contexts/clear-storage-context';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { selectShouldUseCPUNoise, shouldUseCpuNoiseChanged } from 'features/controlLayers/store/paramsSlice';
import { useRefreshAfterResetModal } from 'features/system/components/SettingsModal/RefreshAfterResetModal';
import { SettingsDeveloperLogIsEnabled } from 'features/system/components/SettingsModal/SettingsDeveloperLogIsEnabled';

View File

@@ -1,193 +1,25 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { AppConfig, NumericalParameterConfig, PartialAppConfig } from 'app/types/invokeai';
import type { SliceConfig } from 'app/store/types';
import { getDefaultAppConfig, type PartialAppConfig, zAppConfig } from 'app/types/invokeai';
import { merge } from 'es-toolkit/compat';
import z from 'zod';
const baseDimensionConfig: NumericalParameterConfig = {
initial: 512, // determined by model selection, unused in practice
sliderMin: 64,
sliderMax: 1536,
numberInputMin: 64,
numberInputMax: 4096,
fineStep: 8,
coarseStep: 64,
};
const zConfigState = z.object({
...zAppConfig.shape,
didLoad: z.boolean(),
});
type ConfigState = z.infer<typeof zConfigState>;
const initialConfigState: AppConfig & { didLoad: boolean } = {
const getInitialState = (): ConfigState => ({
...getDefaultAppConfig(),
didLoad: false,
isLocal: true,
shouldUpdateImagesOnConnect: false,
shouldFetchMetadataFromApi: false,
allowPrivateBoards: false,
allowPrivateStylePresets: false,
allowClientSideUpload: false,
allowPublishWorkflows: false,
allowPromptExpansion: false,
shouldShowCredits: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'],
nodesAllowlist: undefined,
nodesDenylist: undefined,
sd: {
disabledControlNetModels: [],
disabledControlNetProcessors: [],
iterations: {
initial: 1,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 1,
},
width: { ...baseDimensionConfig },
height: { ...baseDimensionConfig },
boundingBoxWidth: { ...baseDimensionConfig },
boundingBoxHeight: { ...baseDimensionConfig },
scaledBoundingBoxWidth: { ...baseDimensionConfig },
scaledBoundingBoxHeight: { ...baseDimensionConfig },
scheduler: 'dpmpp_3m_k',
vaePrecision: 'fp32',
steps: {
initial: 30,
sliderMin: 1,
sliderMax: 100,
numberInputMin: 1,
numberInputMax: 500,
fineStep: 1,
coarseStep: 1,
},
guidance: {
initial: 7,
sliderMin: 1,
sliderMax: 20,
numberInputMin: 1,
numberInputMax: 200,
fineStep: 0.1,
coarseStep: 0.5,
},
img2imgStrength: {
initial: 0.7,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceStrength: {
initial: 0.3,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
hrfStrength: {
initial: 0.45,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceEdgeSize: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 1024,
fineStep: 8,
coarseStep: 16,
},
cfgRescaleMultiplier: {
initial: 0,
sliderMin: 0,
sliderMax: 0.99,
numberInputMin: 0,
numberInputMax: 0.99,
fineStep: 0.05,
coarseStep: 0.1,
},
clipSkip: {
initial: 0,
sliderMin: 0,
sliderMax: 12, // determined by model selection, unused in practice
numberInputMin: 0,
numberInputMax: 12, // determined by model selection, unused in practice
fineStep: 1,
coarseStep: 1,
},
infillPatchmatchDownscaleSize: {
initial: 1,
sliderMin: 1,
sliderMax: 10,
numberInputMin: 1,
numberInputMax: 10,
fineStep: 1,
coarseStep: 1,
},
infillTileSize: {
initial: 32,
sliderMin: 16,
sliderMax: 64,
numberInputMin: 16,
numberInputMax: 256,
fineStep: 1,
coarseStep: 1,
},
maskBlur: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 512,
fineStep: 1,
coarseStep: 1,
},
ca: {
weight: {
initial: 1,
sliderMin: 0,
sliderMax: 2,
numberInputMin: -1,
numberInputMax: 2,
fineStep: 0.01,
coarseStep: 0.05,
},
},
dynamicPrompts: {
maxPrompts: {
initial: 100,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 10,
},
},
},
flux: {
guidance: {
initial: 4,
sliderMin: 2,
sliderMax: 6,
numberInputMin: 1,
numberInputMax: 20,
fineStep: 0.1,
coarseStep: 0.5,
},
},
};
});
export const configSlice = createSlice({
const slice = createSlice({
name: 'config',
initialState: initialConfigState,
initialState: getInitialState(),
reducers: {
configChanged: (state, action: PayloadAction<PartialAppConfig>) => {
merge(state, action.payload);
@@ -196,11 +28,16 @@ export const configSlice = createSlice({
},
});
export const { configChanged } = configSlice.actions;
export const { configChanged } = slice.actions;
export const configSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zConfigState,
getInitialState,
};
export const selectConfigSlice = (state: RootState) => state.config;
const createConfigSelector = <T>(selector: Selector<typeof initialConfigState, T>) =>
createSelector(selectConfigSlice, selector);
const createConfigSelector = <T>(selector: Selector<ConfigState, T>) => createSelector(selectConfigSlice, selector);
export const selectWidthConfig = createConfigSelector((config) => config.sd.width);
export const selectHeightConfig = createConfigSelector((config) => config.sd.height);

View File

@@ -3,12 +3,15 @@ import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { LogNamespace } from 'app/logging/logger';
import { zLogNamespace } from 'app/logging/logger';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { uniq } from 'es-toolkit/compat';
import { assert } from 'tsafe';
import type { Language, SystemState } from './types';
import { type Language, type SystemState, zSystemState } from './types';
const initialSystemState: SystemState = {
const getInitialState = (): SystemState => ({
_version: 2,
shouldConfirmOnDelete: true,
shouldAntialiasProgressImage: false,
@@ -23,11 +26,11 @@ const initialSystemState: SystemState = {
logNamespaces: [...zLogNamespace.options],
shouldShowInvocationProgressDetail: false,
shouldHighlightFocusedRegions: false,
};
});
export const systemSlice = createSlice({
const slice = createSlice({
name: 'system',
initialState: initialSystemState,
initialState: getInitialState(),
reducers: {
setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => {
state.shouldConfirmOnDelete = action.payload;
@@ -89,25 +92,25 @@ export const {
shouldConfirmOnNewSessionToggled,
setShouldShowInvocationProgressDetail,
setShouldHighlightFocusedRegions,
} = systemSlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateSystemState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state.language = (state as SystemState).language.replace('_', '-');
state._version = 2;
}
return state;
};
export const systemPersistConfig: PersistConfig<SystemState> = {
name: systemSlice.name,
initialState: initialSystemState,
migrate: migrateSystemState,
persistDenylist: [],
export const systemSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zSystemState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state.language = (state as SystemState).language.replace('_', '-');
state._version = 2;
}
return zSystemState.parse(state);
},
},
};
export const selectSystemSlice = (state: RootState) => state.system;

View File

@@ -1,4 +1,4 @@
import type { LogLevel, LogNamespace } from 'app/logging/logger';
import { zLogLevel, zLogNamespace } from 'app/logging/logger';
import { z } from 'zod';
const zLanguage = z.enum([
@@ -29,19 +29,20 @@ const zLanguage = z.enum([
export type Language = z.infer<typeof zLanguage>;
export const isLanguage = (v: unknown): v is Language => zLanguage.safeParse(v).success;
export interface SystemState {
_version: 2;
shouldConfirmOnDelete: boolean;
shouldAntialiasProgressImage: boolean;
shouldConfirmOnNewSession: boolean;
language: Language;
shouldUseNSFWChecker: boolean;
shouldUseWatermarker: boolean;
shouldEnableInformationalPopovers: boolean;
shouldEnableModelDescriptions: boolean;
logIsEnabled: boolean;
logLevel: LogLevel;
logNamespaces: LogNamespace[];
shouldShowInvocationProgressDetail: boolean;
shouldHighlightFocusedRegions: boolean;
}
export const zSystemState = z.object({
_version: z.literal(2),
shouldConfirmOnDelete: z.boolean(),
shouldAntialiasProgressImage: z.boolean(),
shouldConfirmOnNewSession: z.boolean(),
language: zLanguage,
shouldUseNSFWChecker: z.boolean(),
shouldUseWatermarker: z.boolean(),
shouldEnableInformationalPopovers: z.boolean(),
shouldEnableModelDescriptions: z.boolean(),
logIsEnabled: z.boolean(),
logLevel: zLogLevel,
logNamespaces: z.array(zLogNamespace),
shouldShowInvocationProgressDetail: z.boolean(),
shouldHighlightFocusedRegions: z.boolean(),
});
export type SystemState = z.infer<typeof zSystemState>;

View File

@@ -1,6 +1,7 @@
import { Box, Button, ButtonGroup, Flex, Grid, Heading, Icon, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import {
@@ -37,7 +38,7 @@ export const UpscalingLaunchpadPanel = memo(() => {
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
dispatch(upscaleInitialImageChanged(imageDTO));
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
},
[dispatch]
);

View File

@@ -1,5 +1,5 @@
import type { DockviewApi, GridviewApi } from 'dockview';
import { DockviewPanel, GridviewPanel } from 'dockview';
import { DockviewApi as MockedDockviewApi, DockviewPanel, GridviewPanel } from 'dockview';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import type { NavigationAppApi } from './navigation-api';
@@ -12,6 +12,7 @@ import {
RIGHT_PANEL_MIN_SIZE_PX,
SETTINGS_PANEL_ID,
SWITCH_TABS_FAKE_DELAY_MS,
VIEWER_PANEL_ID,
WORKSPACE_PANEL_ID,
} from './shared';
@@ -48,7 +49,7 @@ vi.mock('dockview', async () => {
}
}
// Mock GridviewPanel class for instanceof checks
// Mock DockviewPanel class for instanceof checks
class MockDockviewPanel {
api = {
setActive: vi.fn(),
@@ -58,10 +59,21 @@ vi.mock('dockview', async () => {
};
}
// Mock DockviewApi class for instanceof checks
class MockDockviewApi {
panels = [];
activePanel = null;
toJSON = vi.fn();
fromJSON = vi.fn();
onDidLayoutChange = vi.fn();
onDidActivePanelChange = vi.fn();
}
return {
...actual,
GridviewPanel: MockGridviewPanel,
DockviewPanel: MockDockviewPanel,
DockviewApi: MockDockviewApi,
};
});
@@ -1105,4 +1117,393 @@ describe('AppNavigationApi', () => {
expect(initialize).not.toHaveBeenCalled();
});
});
describe('toggleViewerPanel', () => {
beforeEach(() => {
navigationApi.connectToApp(mockAppApi);
});
it('should switch to viewer panel when not currently on viewer', async () => {
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current panel to something other than viewer
navigationApi._currentActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(true);
expect(mockViewerPanel.api.setActive).toHaveBeenCalledOnce();
});
it('should switch to previous panel when on viewer and previous panel exists', async () => {
const mockPreviousPanel = createMockDockPanel();
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', SETTINGS_PANEL_ID, mockPreviousPanel);
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current panel to viewer and previous to settings
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(true);
expect(mockPreviousPanel.api.setActive).toHaveBeenCalledOnce();
expect(mockViewerPanel.api.setActive).not.toHaveBeenCalled();
});
it('should switch to launchpad when on viewer and no valid previous panel', async () => {
const mockLaunchpadPanel = createMockDockPanel();
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', LAUNCHPAD_PANEL_ID, mockLaunchpadPanel);
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current panel to viewer and no previous panel
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', null);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(true);
expect(mockLaunchpadPanel.api.setActive).toHaveBeenCalledOnce();
expect(mockViewerPanel.api.setActive).not.toHaveBeenCalled();
});
it('should switch to launchpad when on viewer and previous panel is also viewer', async () => {
const mockLaunchpadPanel = createMockDockPanel();
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', LAUNCHPAD_PANEL_ID, mockLaunchpadPanel);
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current panel to viewer and previous panel was also viewer
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(true);
expect(mockLaunchpadPanel.api.setActive).toHaveBeenCalledOnce();
expect(mockViewerPanel.api.setActive).not.toHaveBeenCalled();
});
it('should return false when no active tab', async () => {
mockGetAppTab.mockReturnValue(null);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(false);
});
it('should return false when viewer panel is not registered', async () => {
mockGetAppTab.mockReturnValue('generate');
navigationApi._currentActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
// Don't register viewer panel
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(false);
});
it('should return false when previous panel is not registered', async () => {
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current to viewer and previous to unregistered panel
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', 'unregistered-panel');
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(false);
});
it('should return false when launchpad panel is not registered as fallback', async () => {
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current to viewer and no previous panel, but don't register launchpad
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', null);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(false);
});
it('should work across different tabs independently', async () => {
const mockViewerPanel1 = createMockDockPanel();
const mockViewerPanel2 = createMockDockPanel();
const mockSettingsPanel1 = createMockDockPanel();
const mockSettingsPanel2 = createMockDockPanel();
const mockLaunchpadPanel = createMockDockPanel();
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel1);
navigationApi._registerPanel('generate', SETTINGS_PANEL_ID, mockSettingsPanel1);
navigationApi._registerPanel('canvas', VIEWER_PANEL_ID, mockViewerPanel2);
navigationApi._registerPanel('canvas', SETTINGS_PANEL_ID, mockSettingsPanel2);
navigationApi._registerPanel('canvas', LAUNCHPAD_PANEL_ID, mockLaunchpadPanel);
// Set up different states for different tabs
navigationApi._currentActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
navigationApi._currentActiveDockviewPanel.set('canvas', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('canvas', SETTINGS_PANEL_ID);
// Test generate tab (should switch to viewer)
mockGetAppTab.mockReturnValue('generate');
const result1 = await navigationApi.toggleViewerPanel();
expect(result1).toBe(true);
expect(mockViewerPanel1.api.setActive).toHaveBeenCalledOnce();
// Test canvas tab (should switch to previous panel - settings panel in canvas)
mockGetAppTab.mockReturnValue('canvas');
const result2 = await navigationApi.toggleViewerPanel();
expect(result2).toBe(true);
expect(mockSettingsPanel2.api.setActive).toHaveBeenCalledOnce();
});
it('should handle sequence of viewer toggles correctly', async () => {
const mockViewerPanel = createMockDockPanel();
const mockSettingsPanel = createMockDockPanel();
const mockLaunchpadPanel = createMockDockPanel();
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
navigationApi._registerPanel('generate', SETTINGS_PANEL_ID, mockSettingsPanel);
navigationApi._registerPanel('generate', LAUNCHPAD_PANEL_ID, mockLaunchpadPanel);
mockGetAppTab.mockReturnValue('generate');
// Start on settings panel
navigationApi._currentActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', null);
// First toggle: settings -> viewer
const result1 = await navigationApi.toggleViewerPanel();
expect(result1).toBe(true);
expect(mockViewerPanel.api.setActive).toHaveBeenCalledOnce();
// Simulate panel change tracking (normally done by dockview listener)
navigationApi._prevActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
// Second toggle: viewer -> settings (previous panel)
const result2 = await navigationApi.toggleViewerPanel();
expect(result2).toBe(true);
expect(mockSettingsPanel.api.setActive).toHaveBeenCalledOnce();
// Simulate panel change tracking again
navigationApi._prevActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._currentActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
// Third toggle: settings -> viewer again
const result3 = await navigationApi.toggleViewerPanel();
expect(result3).toBe(true);
expect(mockViewerPanel.api.setActive).toHaveBeenCalledTimes(2);
});
});
describe('Disposable Cleanup', () => {
beforeEach(() => {
navigationApi.connectToApp(mockAppApi);
});
it('should add disposable functions for a tab', () => {
const dispose1 = vi.fn();
const dispose2 = vi.fn();
navigationApi._addDisposeForTab('generate', dispose1);
navigationApi._addDisposeForTab('generate', dispose2);
// Check that disposables are stored
const disposables = navigationApi._disposablesForTab.get('generate');
expect(disposables).toBeDefined();
expect(disposables?.size).toBe(2);
expect(disposables?.has(dispose1)).toBe(true);
expect(disposables?.has(dispose2)).toBe(true);
});
it('should handle multiple tabs independently', () => {
const dispose1 = vi.fn();
const dispose2 = vi.fn();
const dispose3 = vi.fn();
navigationApi._addDisposeForTab('generate', dispose1);
navigationApi._addDisposeForTab('generate', dispose2);
navigationApi._addDisposeForTab('canvas', dispose3);
const generateDisposables = navigationApi._disposablesForTab.get('generate');
const canvasDisposables = navigationApi._disposablesForTab.get('canvas');
expect(generateDisposables?.size).toBe(2);
expect(canvasDisposables?.size).toBe(1);
expect(generateDisposables?.has(dispose1)).toBe(true);
expect(generateDisposables?.has(dispose2)).toBe(true);
expect(canvasDisposables?.has(dispose3)).toBe(true);
});
it('should call all dispose functions when unregistering a tab', () => {
const dispose1 = vi.fn();
const dispose2 = vi.fn();
const dispose3 = vi.fn();
// Add disposables for generate tab
navigationApi._addDisposeForTab('generate', dispose1);
navigationApi._addDisposeForTab('generate', dispose2);
// Add disposable for canvas tab (should not be called)
navigationApi._addDisposeForTab('canvas', dispose3);
// Unregister generate tab
navigationApi.unregisterTab('generate');
// Check that generate tab disposables were called
expect(dispose1).toHaveBeenCalledOnce();
expect(dispose2).toHaveBeenCalledOnce();
// Check that canvas tab disposable was not called
expect(dispose3).not.toHaveBeenCalled();
// Check that generate tab disposables are cleared
expect(navigationApi._disposablesForTab.has('generate')).toBe(false);
// Check that canvas tab disposables remain
expect(navigationApi._disposablesForTab.has('canvas')).toBe(true);
});
it('should handle unregistering tab with no disposables gracefully', () => {
// Should not throw when unregistering tab with no disposables
expect(() => navigationApi.unregisterTab('generate')).not.toThrow();
});
it('should handle duplicate dispose functions', () => {
const dispose1 = vi.fn();
// Add the same dispose function twice
navigationApi._addDisposeForTab('generate', dispose1);
navigationApi._addDisposeForTab('generate', dispose1);
const disposables = navigationApi._disposablesForTab.get('generate');
// Set should contain only one instance (sets don't allow duplicates)
expect(disposables?.size).toBe(1);
navigationApi.unregisterTab('generate');
// Should be called only once despite being added twice
expect(dispose1).toHaveBeenCalledOnce();
});
it('should automatically add dispose functions during container registration with DockviewApi', () => {
const tab = 'generate';
const viewId = 'myView';
mockGetStorage.mockReturnValue(undefined);
const initialize = vi.fn();
const panel = { id: 'p1' };
const mockDispose = vi.fn();
// Create a mock that will pass the instanceof DockviewApi check
const mockApi = Object.create(MockedDockviewApi.prototype);
Object.assign(mockApi, {
panels: [panel],
activePanel: { id: 'p1' },
toJSON: vi.fn(() => ({ foo: 'bar' })),
onDidLayoutChange: vi.fn(() => ({ dispose: vi.fn() })),
onDidActivePanelChange: vi.fn(() => ({ dispose: mockDispose })),
});
navigationApi.registerContainer(tab, viewId, mockApi, initialize);
// Check that dispose function was added to disposables
const disposables = navigationApi._disposablesForTab.get(tab);
expect(disposables).toBeDefined();
expect(disposables?.size).toBe(1);
// Unregister tab and check dispose was called
navigationApi.unregisterTab(tab);
expect(mockDispose).toHaveBeenCalledOnce();
});
it('should not add dispose functions for GridviewApi during container registration', () => {
const tab = 'generate';
const viewId = 'myView';
mockGetStorage.mockReturnValue(undefined);
const initialize = vi.fn();
const panel = { id: 'p1' };
// Mock GridviewApi (not DockviewApi)
const mockApi = {
panels: [panel],
toJSON: vi.fn(() => ({ foo: 'bar' })),
onDidLayoutChange: vi.fn(() => ({ dispose: vi.fn() })),
} as unknown as GridviewApi;
navigationApi.registerContainer(tab, viewId, mockApi, initialize);
// Check that no dispose function was added for GridviewApi
const disposables = navigationApi._disposablesForTab.get(tab);
expect(disposables).toBeUndefined();
});
it('should handle dispose function errors gracefully', () => {
const goodDispose = vi.fn();
const errorDispose = vi.fn(() => {
throw new Error('Dispose error');
});
const anotherGoodDispose = vi.fn();
navigationApi._addDisposeForTab('generate', goodDispose);
navigationApi._addDisposeForTab('generate', errorDispose);
navigationApi._addDisposeForTab('generate', anotherGoodDispose);
// Should not throw even if one dispose function throws
expect(() => navigationApi.unregisterTab('generate')).not.toThrow();
// All dispose functions should have been called
expect(goodDispose).toHaveBeenCalledOnce();
expect(errorDispose).toHaveBeenCalledOnce();
expect(anotherGoodDispose).toHaveBeenCalledOnce();
});
it('should clear panel tracking state when unregistering tab', () => {
const tab = 'generate';
// Set up some panel tracking state
navigationApi._currentActiveDockviewPanel.set(tab, VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set(tab, SETTINGS_PANEL_ID);
// Add some disposables
const dispose1 = vi.fn();
const dispose2 = vi.fn();
navigationApi._addDisposeForTab(tab, dispose1);
navigationApi._addDisposeForTab(tab, dispose2);
// Verify state exists before unregistering
expect(navigationApi._currentActiveDockviewPanel.has(tab)).toBe(true);
expect(navigationApi._prevActiveDockviewPanel.has(tab)).toBe(true);
expect(navigationApi._disposablesForTab.has(tab)).toBe(true);
// Unregister tab
navigationApi.unregisterTab(tab);
// Verify all state is cleared
expect(navigationApi._currentActiveDockviewPanel.has(tab)).toBe(false);
expect(navigationApi._prevActiveDockviewPanel.has(tab)).toBe(false);
expect(navigationApi._disposablesForTab.has(tab)).toBe(false);
// Verify dispose functions were called
expect(dispose1).toHaveBeenCalledOnce();
expect(dispose2).toHaveBeenCalledOnce();
});
});
});

View File

@@ -1,19 +1,21 @@
import { logger } from 'app/logging/logger';
import { createDeferredPromise, type Deferred } from 'common/util/createDeferredPromise';
import { parseify } from 'common/util/serialize';
import type { DockviewApi, GridviewApi, IDockviewPanel, IGridviewPanel } from 'dockview';
import { GridviewPanel } from 'dockview';
import type { GridviewApi, IDockviewPanel, IGridviewPanel } from 'dockview';
import { DockviewApi, GridviewPanel } from 'dockview';
import { debounce } from 'es-toolkit';
import type { Serializable, TabName } from 'features/ui/store/uiTypes';
import type { Atom } from 'nanostores';
import { atom } from 'nanostores';
import {
LAUNCHPAD_PANEL_ID,
LEFT_PANEL_ID,
LEFT_PANEL_MIN_SIZE_PX,
RIGHT_PANEL_ID,
RIGHT_PANEL_MIN_SIZE_PX,
SWITCH_TABS_FAKE_DELAY_MS,
VIEWER_PANEL_ID,
} from './shared';
const log = logger('system');
@@ -69,6 +71,37 @@ export class NavigationApi {
private _$isLoading = atom(false);
$isLoading: Atom<boolean> = this._$isLoading;
/**
* Track the _previous_ active dockview panel for each tab.
*/
_prevActiveDockviewPanel: Map<TabName, string | null> = new Map();
/**
* Track the _current_ active dockview panel for each tab.
*/
_currentActiveDockviewPanel: Map<TabName, string | null> = new Map();
/**
* Map of disposables for each tab.
* This is used to clean up resources when a tab is unregistered.
*/
_disposablesForTab: Map<TabName, Set<() => void>> = new Map();
/**
* Convenience method to add a dispose function for a specific tab.
*/
/**
* Convenience method to add a dispose function for a specific tab.
*/
_addDisposeForTab = (tab: TabName, disposeFn: () => void): void => {
let disposables = this._disposablesForTab.get(tab);
if (!disposables) {
disposables = new Set<() => void>();
this._disposablesForTab.set(tab, disposables);
}
disposables.add(disposeFn);
};
/**
* Separator used to create unique keys for panels. Typo protection.
*/
@@ -209,6 +242,18 @@ export class NavigationApi {
this._registerPanel(tab, panel.id, panel);
}
// Set up tracking for active tab for this panel - needed for viewer toggle functionality
if (api instanceof DockviewApi) {
this._currentActiveDockviewPanel.set(tab, api.activePanel?.id ?? null);
this._prevActiveDockviewPanel.set(tab, null);
const { dispose } = api.onDidActivePanelChange((panel) => {
const previousPanelId = this._currentActiveDockviewPanel.get(tab);
this._prevActiveDockviewPanel.set(tab, previousPanelId ?? null);
this._currentActiveDockviewPanel.set(tab, panel?.id ?? null);
});
this._addDisposeForTab(tab, dispose);
}
api.onDidLayoutChange(
debounce(() => {
this._app?.storage.set(key, api.toJSON());
@@ -545,6 +590,42 @@ export class NavigationApi {
return true;
};
/**
* Toggle between the viewer panel and the previously focused dockview panel in the current tab.
* If currently on viewer and a previous panel exists, switch to the previous panel.
* If not on viewer, switch to viewer.
* If no previous panel exists, defaults to launchpad panel.
* Only operates on dockview panels (panels with tabs), not gridview panels.
*
* @returns Promise that resolves to true if successful, false otherwise
*/
toggleViewerPanel = (): Promise<boolean> => {
const activeTab = this._app?.activeTab.get() ?? null;
if (!activeTab) {
log.warn('No active tab found for viewer toggle');
return Promise.resolve(false);
}
const prevActiveDockviewPanel = this._prevActiveDockviewPanel.get(activeTab);
const currentActiveDockviewPanel = this._currentActiveDockviewPanel.get(activeTab);
let targetPanel;
if (currentActiveDockviewPanel !== VIEWER_PANEL_ID) {
targetPanel = VIEWER_PANEL_ID;
} else if (prevActiveDockviewPanel && prevActiveDockviewPanel !== VIEWER_PANEL_ID) {
targetPanel = prevActiveDockviewPanel;
} else {
targetPanel = LAUNCHPAD_PANEL_ID;
}
if (this.getRegisteredPanels(activeTab).includes(targetPanel)) {
return this.focusPanel(activeTab, targetPanel);
}
return Promise.resolve(false);
};
/**
* Check if a panel is registered.
* @param tab - The tab the panel belongs to
@@ -593,6 +674,18 @@ export class NavigationApi {
this.waiters.delete(key);
}
// Clear previous panel tracking for this tab
this._prevActiveDockviewPanel.delete(tab);
this._currentActiveDockviewPanel.delete(tab);
this._disposablesForTab.get(tab)?.forEach((disposeFn) => {
try {
disposeFn();
} catch (error) {
log.error({ error: parseify(error) }, `Error disposing resource for tab ${tab}`);
}
});
this._disposablesForTab.delete(tab);
log.trace(`Unregistered all panels for tab ${tab}`);
};
}

View File

@@ -1,11 +1,13 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { assert } from 'tsafe';
import type { UIState } from './uiTypes';
import { getInitialUIState } from './uiTypes';
import { getInitialUIState, type UIState, zUIState } from './uiTypes';
export const uiSlice = createSlice({
const slice = createSlice({
name: 'ui',
initialState: getInitialUIState(),
reducers: {
@@ -81,29 +83,30 @@ export const {
textAreaSizesStateChanged,
dockviewStorageKeyChanged,
pickerCompactViewStateChanged,
} = uiSlice.actions;
} = slice.actions;
export const selectUiSlice = (state: RootState) => state.ui;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateUIState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state.activeTab = 'generation';
state._version = 2;
}
if (state._version === 2) {
state.activeTab = 'canvas';
state._version = 3;
}
return state;
};
export const uiPersistConfig: PersistConfig<UIState> = {
name: uiSlice.name,
initialState: getInitialUIState(),
migrate: migrateUIState,
persistDenylist: ['shouldShowImageDetails'],
export const uiSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zUIState,
getInitialState: getInitialUIState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state.activeTab = 'generation';
state._version = 2;
}
if (state._version === 2) {
state.activeTab = 'canvas';
state._version = 3;
}
return zUIState.parse(state);
},
persistDenylist: ['shouldShowImageDetails'],
},
};

View File

@@ -1,8 +1,7 @@
import { deepClone } from 'common/util/deepClone';
import { isPlainObject } from 'es-toolkit';
import { z } from 'zod';
const zTabName = z.enum(['generate', 'canvas', 'upscaling', 'workflows', 'models', 'queue']);
export const zTabName = z.enum(['generate', 'canvas', 'upscaling', 'workflows', 'models', 'queue']);
export type TabName = z.infer<typeof zTabName>;
const zPartialDimensions = z.object({
@@ -13,18 +12,28 @@ const zPartialDimensions = z.object({
const zSerializable = z.any().refine(isPlainObject);
export type Serializable = z.infer<typeof zSerializable>;
const zUIState = z.object({
_version: z.literal(3).default(3),
activeTab: zTabName.default('generate'),
shouldShowImageDetails: z.boolean().default(false),
shouldShowProgressInViewer: z.boolean().default(true),
accordions: z.record(z.string(), z.boolean()).default(() => ({})),
expanders: z.record(z.string(), z.boolean()).default(() => ({})),
textAreaSizes: z.record(z.string(), zPartialDimensions).default({}),
panels: z.record(z.string(), zSerializable).default({}),
shouldShowNotificationV2: z.boolean().default(true),
pickerCompactViewStates: z.record(z.string(), z.boolean()).default(() => ({})),
export const zUIState = z.object({
_version: z.literal(3),
activeTab: zTabName,
shouldShowImageDetails: z.boolean(),
shouldShowProgressInViewer: z.boolean(),
accordions: z.record(z.string(), z.boolean()),
expanders: z.record(z.string(), z.boolean()),
textAreaSizes: z.record(z.string(), zPartialDimensions),
panels: z.record(z.string(), zSerializable),
shouldShowNotificationV2: z.boolean(),
pickerCompactViewStates: z.record(z.string(), z.boolean()),
});
const INITIAL_STATE = zUIState.parse({});
export type UIState = z.infer<typeof zUIState>;
export const getInitialUIState = (): UIState => deepClone(INITIAL_STATE);
export const getInitialUIState = (): UIState => ({
_version: 3 as const,
activeTab: 'generate' as const,
shouldShowImageDetails: false,
shouldShowProgressInViewer: true,
accordions: {},
expanders: {},
textAreaSizes: {},
panels: {},
shouldShowNotificationV2: true,
pickerCompactViewStates: {},
});

View File

@@ -1,5 +1,6 @@
import { $openAPISchemaUrl } from 'app/store/nanostores/openAPISchemaUrl';
import type { OpenAPIV3_1 } from 'openapi-types';
import type { stringify } from 'querystring';
import type { paths } from 'services/api/schema';
import type { AppConfig, AppVersion } from 'services/api/types';
@@ -11,7 +12,8 @@ import { api, buildV1Url } from '..';
* buildAppInfoUrl('some-path')
* // '/api/v1/app/some-path'
*/
const buildAppInfoUrl = (path: string = '') => buildV1Url(`app/${path}`);
export const buildAppInfoUrl = (path: string = '', query?: Parameters<typeof stringify>[0]) =>
buildV1Url(`app/${path}`, query);
export const appInfoApi = api.injectEndpoints({
endpoints: (build) => ({
@@ -87,6 +89,31 @@ export const appInfoApi = api.injectEndpoints({
},
providesTags: ['Schema'],
}),
getClientStateByKey: build.query<
paths['/api/v1/app/client_state']['get']['responses']['200']['content']['application/json'],
paths['/api/v1/app/client_state']['get']['parameters']['query']
>({
query: () => ({
url: buildAppInfoUrl('client_state'),
method: 'GET',
}),
}),
setClientStateByKey: build.mutation<
paths['/api/v1/app/client_state']['post']['responses']['200']['content']['application/json'],
paths['/api/v1/app/client_state']['post']['requestBody']['content']['application/json']
>({
query: (body) => ({
url: buildAppInfoUrl('client_state'),
method: 'POST',
body,
}),
}),
deleteClientState: build.mutation<void, void>({
query: () => ({
url: buildAppInfoUrl('client_state'),
method: 'DELETE',
}),
}),
}),
});

View File

@@ -57,13 +57,18 @@ const tagTypes = [
// This is invalidated on reconnect. It should be used for queries that have changing data,
// especially related to the queue and generation.
'FetchOnReconnect',
'ClientState',
] as const;
export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
export const LIST_TAG = 'LIST';
export const LIST_ALL_TAG = 'LIST_ALL';
const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryError> = (args, api, extraOptions) => {
export const getBaseUrl = (): string => {
const baseUrl = $baseUrl.get();
return baseUrl || window.location.href.replace(/\/$/, '');
};
const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryError> = (args, api, extraOptions) => {
const authToken = $authToken.get();
const projectId = $projectId.get();
const isOpenAPIRequest =
@@ -71,7 +76,7 @@ const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryE
(typeof args === 'string' && args.includes('openapi.json'));
const fetchBaseQueryArgs: FetchBaseQueryArgs = {
baseUrl: baseUrl || window.location.href.replace(/\/$/, ''),
baseUrl: getBaseUrl(),
};
// When fetching the openapi.json, we need to remove circular references from the JSON.

View File

@@ -1164,6 +1164,34 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/app/client_state": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get Client State By Key
* @description Gets the client state
*/
get: operations["get_client_state_by_key"];
put?: never;
/**
* Set Client State
* @description Sets the client state
*/
post: operations["set_client_state"];
/**
* Delete Client State
* @description Deletes the client state
*/
delete: operations["delete_client_state"];
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/queue/{queue_id}/enqueue_batch": {
parameters: {
query?: never;
@@ -24697,6 +24725,101 @@ export interface operations {
};
};
};
get_client_state_by_key: {
parameters: {
query: {
/** @description Key to get */
key: string;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["JsonValue"] | null;
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
set_client_state: {
parameters: {
query: {
/** @description Key to set */
key: string;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["JsonValue"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": unknown;
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
delete_client_state: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": unknown;
};
};
/** @description Client state deleted */
204: {
headers: {
[name: string]: unknown;
};
content?: never;
};
};
};
enqueue_batch: {
parameters: {
query?: never;

Some files were not shown because too many files have changed in this diff Show More