Compare commits

..

1 Commits

Author SHA1 Message Date
psychedelicious
63d22336f6 chore: bump version to v6.1.0 2025-07-22 08:12:31 +10:00
104 changed files with 1344 additions and 2963 deletions

View File

@@ -10,7 +10,6 @@ from invokeai.app.services.board_images.board_images_default import BoardImagesS
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.boards.boards_default import BoardService
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.download.download_default import DownloadQueueService
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
@@ -152,7 +151,6 @@ class ApiDependencies:
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
client_state_persistence = ClientStatePersistenceSqlite(db=db)
services = InvocationServices(
board_image_records=board_image_records,
@@ -183,7 +181,6 @@ class ApiDependencies:
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
workflow_thumbnails=workflow_thumbnails,
client_state_persistence=client_state_persistence,
)
ApiDependencies.invoker = Invoker(services)

View File

@@ -5,9 +5,9 @@ from pathlib import Path
from typing import Optional
import torch
from fastapi import Body, HTTPException, Query
from fastapi import Body
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, JsonValue
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.upscale import ESRGAN_MODELS
@@ -173,50 +173,3 @@ async def disable_invocation_cache() -> None:
async def get_invocation_cache_status() -> InvocationCacheStatus:
"""Clears the invocation cache"""
return ApiDependencies.invoker.services.invocation_cache.get_status()
@app_router.get(
"/client_state",
operation_id="get_client_state_by_key",
response_model=JsonValue | None,
)
async def get_client_state_by_key(
key: str = Query(..., description="Key to get"),
) -> JsonValue | None:
"""Gets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key)
except Exception as e:
logging.error(f"Error getting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@app_router.post(
"/client_state",
operation_id="set_client_state",
response_model=None,
)
async def set_client_state(
key: str = Query(..., description="Key to set"),
value: JsonValue = Body(..., description="Value of the key"),
) -> None:
"""Sets the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@app_router.delete(
"/client_state",
operation_id="delete_client_state",
responses={204: {"description": "Client state deleted"}},
)
async def delete_client_state() -> None:
"""Deletes the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.delete()
except Exception as e:
logging.error(f"Error deleting client state: {e}")
raise HTTPException(status_code=500, detail="Error deleting client state")

View File

@@ -1,35 +0,0 @@
from abc import ABC, abstractmethod
from pydantic import JsonValue
class ClientStatePersistenceABC(ABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
@abstractmethod
def set_by_key(self, key: str, value: JsonValue) -> None:
"""
Store the data for the client.
:param data: The client data to be stored.
"""
pass
@abstractmethod
def get_by_key(self, key: str) -> JsonValue | None:
"""
Get the data for the client.
:return: The client data.
"""
pass
@abstractmethod
def delete(self) -> None:
"""
Delete the data for the client.
"""
pass

View File

@@ -1,65 +0,0 @@
import json
from pydantic import JsonValue
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._default_row_id = 1
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def set_by_key(self, key: str, value: JsonValue) -> None:
state = self.get() or {}
state.update({key: value})
with self._db.transaction() as cursor:
cursor.execute(
f"""
INSERT INTO client_state (id, data)
VALUES ({self._default_row_id}, ?)
ON CONFLICT(id) DO UPDATE
SET data = excluded.data;
""",
(json.dumps(state),),
)
def get(self) -> dict[str, JsonValue] | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])
def get_by_key(self, key: str) -> JsonValue | None:
state = self.get()
if state is None:
return None
return state.get(key, None)
def delete(self) -> None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
DELETE FROM client_state
WHERE id = {self._default_row_id}
"""
)

View File

@@ -17,7 +17,6 @@ if TYPE_CHECKING:
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
@@ -74,7 +73,6 @@ class InvocationServices:
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
workflow_thumbnails: "WorkflowThumbnailServiceBase",
client_state_persistence: "ClientStatePersistenceABC",
):
self.board_images = board_images
self.board_image_records = board_image_records
@@ -104,4 +102,3 @@ class InvocationServices:
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files
self.workflow_thumbnails = workflow_thumbnails
self.client_state_persistence = client_state_persistence

View File

@@ -23,7 +23,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -64,7 +63,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_18())
migrator.register_migration(build_migration_19(app_config=config))
migrator.register_migration(build_migration_20())
migrator.register_migration(build_migration_21())
migrator.run_migrations()
return db

View File

@@ -1,40 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration21Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""
CREATE TABLE client_state (
id INTEGER PRIMARY KEY CHECK(id = 1),
data TEXT NOT NULL, -- Frontend will handle the shape of this data
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP)
);
"""
)
cursor.execute(
"""
CREATE TRIGGER tg_client_state_updated_at
AFTER UPDATE ON client_state
FOR EACH ROW
BEGIN
UPDATE client_state
SET updated_at = CURRENT_TIMESTAMP
WHERE id = OLD.id;
END;
"""
)
def build_migration_21() -> Migration:
"""Builds the migration object for migrating from version 20 to version 21. This includes:
- Creating the `client_state` table.
- Adding a trigger to update the `updated_at` field on updates.
"""
return Migration(
from_version=20,
to_version=21,
callback=Migration21Callback(),
)

View File

@@ -44,5 +44,4 @@ yalc.lock
# vitest
tsconfig.vitest-temp.json
coverage/
*.tgz
coverage/

View File

@@ -26,7 +26,7 @@ i18n.use(initReactI18next).init({
returnNull: false,
});
const store = createStore({ driver: { getItem: () => {}, setItem: () => {} }, persistThrottle: 2000 });
const store = createStore(undefined, false);
$store.set(store);
$baseUrl.set('http://localhost:9090');

View File

@@ -197,10 +197,6 @@ export default [
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
{
name: 'zod/v3',
message: 'Import from zod instead.',
},
],
},
],

View File

@@ -63,6 +63,7 @@
"framer-motion": "^11.10.0",
"i18next": "^25.3.2",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "6.2.2",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.22",
"linkify-react": "^4.3.1",
@@ -102,7 +103,7 @@
"use-debounce": "^10.0.5",
"use-device-pixel-ratio": "^1.1.2",
"uuid": "^11.1.0",
"zod": "^4.0.10",
"zod": "^4.0.5",
"zod-validation-error": "^3.5.2"
},
"peerDependencies": {

View File

@@ -80,6 +80,9 @@ importers:
i18next-http-backend:
specifier: ^3.0.2
version: 3.0.2
idb-keyval:
specifier: 6.2.2
version: 6.2.2
jsondiffpatch:
specifier: ^0.7.3
version: 0.7.3
@@ -198,11 +201,11 @@ importers:
specifier: ^11.1.0
version: 11.1.0
zod:
specifier: ^4.0.10
version: 4.0.10
specifier: ^4.0.5
version: 4.0.5
zod-validation-error:
specifier: ^3.5.2
version: 3.5.3(zod@4.0.10)
version: 3.5.3(zod@4.0.5)
devDependencies:
'@eslint/js':
specifier: ^9.31.0
@@ -408,10 +411,6 @@ packages:
resolution: {integrity: sha512-vbavdySgbTTrmFE+EsiqUTzlOr5bzlnJtUv9PynGCAKvfQqjIXbvFdumPM/GxMDfyuGMJaJAU6TO4zc1Jf1i8Q==}
engines: {node: '>=6.9.0'}
'@babel/runtime@7.28.2':
resolution: {integrity: sha512-KHp2IflsnGywDjBWDkR9iEqiWSpc8GIi0lgTT3mOElT0PP1tG26P4tmFI2YvAdzgq9RGyoHZQEIEdZy6Ec5xCA==}
engines: {node: '>=6.9.0'}
'@babel/template@7.27.2':
resolution: {integrity: sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==}
engines: {node: '>=6.9.0'}
@@ -2772,6 +2771,9 @@ packages:
typescript:
optional: true
idb-keyval@6.2.2:
resolution: {integrity: sha512-yjD9nARJ/jb1g+CvD0tlhUHOrJ9Sy0P8T9MF3YaLlHnSRpwPfpTX0XIvpmw3gAJUmEu3FiICLBDPXVwyEvrleg==}
ieee754@1.2.1:
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
@@ -4509,8 +4511,8 @@ packages:
zod@3.25.76:
resolution: {integrity: sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==}
zod@4.0.10:
resolution: {integrity: sha512-3vB+UU3/VmLL2lvwcY/4RV2i9z/YU0DTV/tDuYjrwmx5WeJ7hwy+rGEEx8glHp6Yxw7ibRbKSaIFBgReRPe5KA==}
zod@4.0.5:
resolution: {integrity: sha512-/5UuuRPStvHXu7RS+gmvRf4NXrNxpSllGwDnCBcJZtQsKrviYXm54yDGV2KYNLT5kq0lHGcl7lqWJLgSaG+tgA==}
zustand@4.5.7:
resolution: {integrity: sha512-CHOUy7mu3lbD6o6LJLfllpjkzhHXSBlX8B9+qPddUsIfeF5S/UZ5q0kmCsnRqT1UHFQZchNFDDzMbQsuesHWlw==}
@@ -4631,8 +4633,6 @@ snapshots:
'@babel/runtime@7.27.6': {}
'@babel/runtime@7.28.2': {}
'@babel/template@7.27.2':
dependencies:
'@babel/code-frame': 7.27.1
@@ -5736,7 +5736,7 @@ snapshots:
'@testing-library/dom@10.4.0':
dependencies:
'@babel/code-frame': 7.27.1
'@babel/runtime': 7.28.2
'@babel/runtime': 7.27.6
'@types/aria-query': 5.0.4
aria-query: 5.3.0
chalk: 4.1.2
@@ -7266,6 +7266,8 @@ snapshots:
optionalDependencies:
typescript: 5.8.3
idb-keyval@6.2.2: {}
ieee754@1.2.1: {}
ignore@5.3.2: {}
@@ -9060,13 +9062,13 @@ snapshots:
dependencies:
zod: 3.25.76
zod-validation-error@3.5.3(zod@4.0.10):
zod-validation-error@3.5.3(zod@4.0.5):
dependencies:
zod: 4.0.10
zod: 4.0.5
zod@3.25.76: {}
zod@4.0.10: {}
zod@4.0.5: {}
zustand@4.5.7(@types/react@18.3.23)(immer@10.1.1)(react@18.3.1):
dependencies:

View File

@@ -253,7 +253,6 @@
"cancel": "Cancel",
"cancelAllExceptCurrentQueueItemAlertDialog": "Canceling all queue items except the current one will stop pending items but allow the in-progress one to finish.",
"cancelAllExceptCurrentQueueItemAlertDialog2": "Are you sure you want to cancel all pending queue items?",
"cancelAllExceptCurrent": "Cancel All Except Current",
"cancelAllExceptCurrentTooltip": "Cancel All Except Current Item",
"cancelTooltip": "Cancel Current Item",
"cancelSucceeded": "Item Canceled",
@@ -274,7 +273,7 @@
"retryItem": "Retry Item",
"cancelBatchSucceeded": "Batch Canceled",
"cancelBatchFailed": "Problem Canceling Batch",
"clearQueueAlertDialog": "Clearing the queue immediately cancels any processing items and clears the queue entirely. Pending filters will be canceled and the Canvas Staging Area will be reset.",
"clearQueueAlertDialog": "Clearing the queue immediately cancels any processing items and clears the queue entirely. Pending filters will be canceled.",
"clearQueueAlertDialog2": "Are you sure you want to clear the queue?",
"current": "Current",
"next": "Next",
@@ -2631,10 +2630,9 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"New setting to send all Canvas generations directly to the Gallery.",
"New Invert Mask (Shift+V) and Fit BBox to Mask (Shift+B) capabilities.",
"Expanded support for Model Thumbnails and configurations.",
"Various other quality of life updates and fixes"
"Generate images faster with new Launchpads and a simplified Generate tab.",
"Edit with prompts using Flux Kontext Dev.",
"Export to PSD, bulk-hide overlays, organize models & images — all in a reimagined interface built for control."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",

View File

@@ -254,16 +254,12 @@
"desc": "Attiva/disattiva il pannello destro."
},
"resetPanelLayout": {
"title": "Ripristina lo schema del pannello",
"desc": "Ripristina le dimensioni e lo schema predefiniti dei pannelli sinistro e destro."
"title": "Ripristina il layout del pannello",
"desc": "Ripristina le dimensioni e il layout predefiniti dei pannelli sinistro e destro."
},
"togglePanels": {
"title": "Attiva/disattiva i pannelli",
"desc": "Mostra o nascondi contemporaneamente i pannelli sinistro e destro."
},
"selectGenerateTab": {
"title": "Seleziona la scheda Genera",
"desc": "Seleziona la scheda Genera."
}
},
"hotkeys": "Tasti di scelta rapida",
@@ -393,23 +389,6 @@
"behavior": "Comportamento",
"display": "Mostra",
"grid": "Griglia"
},
"invertMask": {
"title": "Inverti maschera",
"desc": "Inverte la maschera di inpaint selezionata, creando una nuova maschera con trasparenza opposta."
},
"fitBboxToMasks": {
"title": "Adatta il riquadro di delimitazione alle maschere",
"desc": "Regola automaticamente il riquadro di delimitazione della generazione per adattarlo alle maschere di inpaint visibili"
},
"applySegmentAnything": {
"title": "Applica Segment Anything",
"desc": "Applica la maschera Segment Anything corrente.",
"key": "invio"
},
"cancelSegmentAnything": {
"title": "Annulla Segment Anything",
"desc": "Annulla l'operazione Segment Anything corrente."
}
},
"workflows": {
@@ -539,10 +518,6 @@
"galleryNavUpAlt": {
"desc": "Uguale a Naviga verso l'alto, ma seleziona l'immagine da confrontare, aprendo la modalità di confronto se non è già aperta.",
"title": "Naviga verso l'alto (Confronta immagine)"
},
"starImage": {
"desc": "Aggiungi/Rimuovi contrassegno all'immagine selezionata.",
"title": "Aggiungi / Rimuovi contrassegno immagine"
}
}
},
@@ -961,15 +936,7 @@
"canvasManagerNotAvailable": "Gestione tela non disponibile",
"promptExpansionFailed": "Abbiamo riscontrato un problema. Riprova a eseguire l'espansione del prompt.",
"uploadAndPromptGenerationFailed": "Impossibile caricare l'immagine e generare il prompt",
"promptGenerationStarted": "Generazione del prompt avviata",
"invalidBboxDesc": "Il riquadro di delimitazione non ha dimensioni valide",
"invalidBbox": "Riquadro di delimitazione non valido",
"noInpaintMaskSelectedDesc": "Seleziona una maschera di inpaint da invertire",
"noInpaintMaskSelected": "Nessuna maschera di inpaint selezionata",
"noVisibleMasksDesc": "Crea o abilita almeno una maschera inpaint da invertire",
"noVisibleMasks": "Nessuna maschera visibile",
"maskInvertFailed": "Impossibile invertire la maschera",
"maskInverted": "Maschera invertita"
"promptGenerationStarted": "Generazione del prompt avviata"
},
"accessibility": {
"invokeProgressBar": "Barra di avanzamento generazione",
@@ -1164,22 +1131,7 @@
"missingField_withName": "Campo \"{{name}}\" mancante",
"unknownFieldEditWorkflowToFix_withName": "Il flusso di lavoro contiene un campo \"{{name}}\" sconosciuto .\nModifica il flusso di lavoro per risolvere il problema.",
"unexpectedField_withName": "Campo \"{{name}}\" inaspettato",
"missingSourceOrTargetHandle": "Identificatore del nodo sorgente o di destinazione mancante",
"layout": {
"alignmentDR": "In basso a destra",
"autoLayout": "Schema automatico",
"nodeSpacing": "Spaziatura nodi",
"layerSpacing": "Spaziatura livelli",
"layeringStrategy": "Strategia livelli",
"longestPath": "Percorso più lungo",
"layoutDirection": "Direzione schema",
"layoutDirectionRight": "Orizzontale",
"layoutDirectionDown": "Verticale",
"alignment": "Allineamento nodi",
"alignmentUL": "In alto a sinistra",
"alignmentDL": "In basso a sinistra",
"alignmentUR": "In alto a destra"
}
"missingSourceOrTargetHandle": "Identificatore del nodo sorgente o di destinazione mancante"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@@ -1256,7 +1208,7 @@
"batchQueuedDesc_other": "Aggiunte {{count}} sessioni a {{direction}} della coda",
"graphQueued": "Grafico in coda",
"batch": "Lotto",
"clearQueueAlertDialog": "La cancellazione della coda annulla immediatamente tutti gli elementi in elaborazione e cancella completamente la coda. I filtri in sospeso verranno annullati e l'area di lavoro della Tela verrà reimpostata.",
"clearQueueAlertDialog": "Lo svuotamento della coda annulla immediatamente tutti gli elementi in elaborazione e cancella completamente la coda. I filtri in sospeso verranno annullati.",
"pending": "In attesa",
"completedIn": "Completato in",
"resumeFailed": "Problema nel riavvio dell'elaborazione",
@@ -1312,8 +1264,7 @@
"retrySucceeded": "Elemento rieseguito",
"retryItem": "Riesegui elemento",
"retryFailed": "Problema riesecuzione elemento",
"credits": "Crediti",
"cancelAllExceptCurrent": "Annulla tutto tranne quello corrente"
"credits": "Crediti"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@@ -1728,7 +1679,7 @@
"structure": {
"heading": "Struttura",
"paragraphs": [
"La struttura determina quanto l'immagine finale rispecchierà il layout dell'originale. Un valore struttura basso permette cambiamenti significativi, mentre un valore struttura alto conserva la composizione e lo schema originali."
"La struttura determina quanto l'immagine finale rispecchierà il layout dell'originale. Una struttura bassa permette cambiamenti significativi, mentre una struttura alta conserva la composizione e il layout originali."
]
},
"fluxDevLicense": {
@@ -1894,7 +1845,7 @@
"opened": "Aperto",
"convertGraph": "Converti grafico",
"loadWorkflow": "$t(common.load) Flusso di lavoro",
"autoLayout": "Schema automatico",
"autoLayout": "Disposizione automatica",
"loadFromGraph": "Carica il flusso di lavoro dal grafico",
"userWorkflows": "Flussi di lavoro utente",
"projectWorkflows": "Flussi di lavoro del progetto",
@@ -2493,9 +2444,7 @@
"switchOnStart": "All'inizio",
"switchOnFinish": "Alla fine",
"off": "Spento"
},
"invertMask": "Inverti maschera",
"fitBboxToMasks": "Adatta il riquadro di delimitazione alle maschere"
}
},
"ui": {
"tabs": {
@@ -2648,10 +2597,9 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"Nuova impostazione per inviare tutte le generazioni della Tela direttamente alla Galleria.",
"Nuove funzionalità Inverti maschera (Maiusc+V) e Adatta il Riquadro di delimitazione alla maschera (Maiusc+B).",
"Supporto esteso per miniature e configurazioni dei modelli.",
"Vari altri aggiornamenti e correzioni per la qualità della vita"
"Genera immagini più velocemente con le nuove Rampe di lancio e una scheda Genera semplificata.",
"Modifica con prompt utilizzando Flux Kontext Dev.",
"Esporta in PSD, nascondi sovrapposizioni in blocco, organizza modelli e immagini: il tutto in un'interfaccia riprogettata e pensata per il controllo."
]
},
"system": {

View File

@@ -2,10 +2,10 @@ import { Box } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import { useClearStorage } from 'app/contexts/clear-storage-context';
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { AppContent } from 'features/ui/components/AppContent';
import { memo, useCallback } from 'react';
import { ErrorBoundary } from 'react-error-boundary';

View File

@@ -1,12 +1,10 @@
import 'i18n';
import type { Middleware } from '@reduxjs/toolkit';
import { ClearStorageProvider } from 'app/contexts/clear-storage-context';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { buildStorageApi } from 'app/store/enhancers/reduxRemember/driver';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
@@ -72,14 +70,6 @@ interface Props extends PropsWithChildren {
* If provided, overrides in-app navigation to the model manager
*/
onClickGoToModelManager?: () => void;
storageConfig?: {
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
getItem: (key: string) => Promise<any>;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
persistThrottle: number;
};
}
const InvokeAIUI = ({
@@ -106,7 +96,6 @@ const InvokeAIUI = ({
loggingOverrides,
onClickGoToModelManager,
whatsNew,
storageConfig,
}: Props) => {
useLayoutEffect(() => {
/*
@@ -319,21 +308,9 @@ const InvokeAIUI = ({
};
}, [isDebugging]);
const storage = useMemo(() => buildStorageApi(storageConfig), [storageConfig]);
useEffect(() => {
const storageCleanup = storage.registerListeners();
return () => {
storageCleanup();
};
}, [storage]);
const store = useMemo(() => {
return createStore({
driver: storage.reduxRememberDriver,
persistThrottle: storageConfig?.persistThrottle ?? 2000,
});
}, [storage.reduxRememberDriver, storageConfig?.persistThrottle]);
return createStore(projectId);
}, [projectId]);
useEffect(() => {
$store.set(store);
@@ -350,13 +327,11 @@ const InvokeAIUI = ({
return (
<React.StrictMode>
<ClearStorageProvider value={storage.clearStorage}>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
</ClearStorageProvider>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
</React.StrictMode>
);
};

View File

@@ -1,10 +0,0 @@
import { createContext, useContext } from 'react';
const ClearStorageContext = createContext<() => void>(() => {});
export const ClearStorageProvider = ClearStorageContext.Provider;
export const useClearStorage = () => {
const context = useContext(ClearStorageContext);
return context;
};

View File

@@ -1,2 +1,3 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = [];
export const EMPTY_OBJECT = {};

View File

@@ -1,243 +1,40 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { logger } from 'app/logging/logger';
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
import { $projectId } from 'app/store/nanostores/projectId';
import type { Driver as ReduxRememberDriver } from 'redux-remember';
import { getBaseUrl } from 'services/api';
import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
import type { UseStore } from 'idb-keyval';
import { clear, createStore as createIDBKeyValStore, get, set } from 'idb-keyval';
import { atom } from 'nanostores';
import type { Driver } from 'redux-remember';
const log = logger('system');
// Create a custom idb-keyval store (just needed to customize the name)
const $idbKeyValStore = atom<UseStore>(createIDBKeyValStore('invoke', 'invoke-store'));
const buildOSSServerBackedDriver = (): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
let persistRefCount = 0;
export const clearIdbKeyValStore = () => {
clear($idbKeyValStore.get());
};
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
//
// `redux-remember` persists individual slices of state, so we can implicity denylist a slice by not giving it a
// persist config.
//
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
// way to do this directly.
//
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
// the implementation in `store.ts` for this logic.
//
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
// whole slice, even if the final, _serialized_ slice value is unchanged.
//
// To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to
// be persisted is the same as the last persisted value, we can skip the network request.
const lastPersistedState = new Map<string, unknown>();
const getUrl = (key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
if (key) {
query['key'] = key;
}
const path = buildAppInfoUrl('client_state', query);
const url = `${baseUrl}/${path}`;
return url;
};
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
const url = getUrl(key);
const res = await fetch(url, { method: 'GET' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
const text = await res.text();
if (!lastPersistedState.get(key)) {
lastPersistedState.set(key, text);
}
return JSON.parse(text);
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key, value) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping persist for key "${key}" as value is unchanged.`);
return value;
}
const url = getUrl(key);
const headers = new Headers({
'Content-Type': 'application/json',
});
const res = await fetch(url, { method: 'POST', headers, body: value });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
lastPersistedState.set(key, value);
return value;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
const clearStorage = async () => {
// Create redux-remember driver, wrapping idb-keyval
export const idbKeyValDriver: Driver = {
getItem: (key) => {
try {
persistRefCount++;
const url = getUrl();
const res = await fetch(url, { method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch {
log.error('Failed to reset client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
return get(key, $idbKeyValStore.get());
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
};
const registerListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};
return { reduxRememberDriver, clearStorage, registerListeners };
};
const buildCustomDriver = (api: {
getItem: (key: string) => Promise<any>;
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
}): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
let persistRefCount = 0;
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
const lastPersistedState = new Map<string, unknown>();
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
log.trace(`Getting client state for key "${key}"`);
return await api.getItem(key);
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key, value) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping setting client state for key "${key}" as value is unchanged`);
return value;
}
log.trace(`Setting client state for key "${key}", ${value}`);
await api.setItem(key, value);
lastPersistedState.set(key, value);
return value;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
const clearStorage = async () => {
},
setItem: (key, value) => {
try {
persistRefCount++;
log.trace('Clearing client state');
await api.clear();
} catch {
log.error('Failed to clear client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
return set(key, value, $idbKeyValStore.get());
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
}
};
const registerListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};
return { reduxRememberDriver, clearStorage, registerListeners };
};
export const buildStorageApi = (api?: {
getItem: (key: string) => Promise<any>;
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
}) => {
if (api) {
return buildCustomDriver(api);
} else {
return buildOSSServerBackedDriver();
}
},
};

View File

@@ -0,0 +1,73 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addImageUploadedFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageUploaded';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
const startAppListening = listenerMiddleware.startListening as AppStartListening;
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
/**
* The RTK listener middleware is a lightweight alternative sagas/observables.
*
* Most side effect logic should live in a listener.
*/
// Image uploaded
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);
// app startup
addAppStartedListener(startAppListening);
addModelsLoadedListener(startAppListening);
addAppConfigReceivedListener(startAppListening);
// Ad-hoc upscale workflwo
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@@ -1,6 +1,6 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';

View File

@@ -1,5 +1,5 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
autoAddBoardIdChanged,

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';

View File

@@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { truncate } from 'es-toolkit/compat';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/state';

View File

@@ -1,5 +1,5 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { size } from 'es-toolkit/compat';
import { $templates } from 'features/nodes/store/nodesSlice';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');

View File

@@ -1,6 +1,7 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening, RootState } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import { omit } from 'es-toolkit/compat';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';

View File

@@ -1,5 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, AppStartListening, RootState } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import {

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/store';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';

View File

@@ -1,8 +1,8 @@
import { objectEquals } from '@observ33r/object-equals';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import type { AppStartListening } from 'app/store/store';
import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';

View File

@@ -1,46 +1,35 @@
import type { ThunkDispatch, TypedStartListening, UnknownAction } from '@reduxjs/toolkit';
import { addListener, combineReducers, configureStore, createListenerMiddleware } from '@reduxjs/toolkit';
import type { ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { deepClone } from 'common/util/deepClone';
import { keys, mergeWith, omit, pick } from 'es-toolkit/compat';
import { changeBoardModalSliceConfig } from 'features/changeBoardModal/store/slice';
import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice';
import { canvasSessionSliceConfig } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice';
import { paramsSliceConfig } from 'features/controlLayers/store/paramsSlice';
import { refImagesSliceConfig } from 'features/controlLayers/store/refImagesSlice';
import { dynamicPromptsSliceConfig } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { gallerySliceConfig } from 'features/gallery/store/gallerySlice';
import { modelManagerSliceConfig } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesSliceConfig } from 'features/nodes/store/nodesSlice';
import { workflowLibrarySliceConfig } from 'features/nodes/store/workflowLibrarySlice';
import { workflowSettingsSliceConfig } from 'features/nodes/store/workflowSettingsSlice';
import { upscaleSliceConfig } from 'features/parameters/store/upscaleSlice';
import { queueSliceConfig } from 'features/queue/store/queueSlice';
import { stylePresetSliceConfig } from 'features/stylePresets/store/stylePresetSlice';
import { configSliceConfig } from 'features/system/store/configSlice';
import { systemSliceConfig } from 'features/system/store/systemSlice';
import { uiSliceConfig } from 'features/ui/store/uiSlice';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice';
import {
canvasSessionSlice,
canvasStagingAreaPersistConfig,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
import { refImagesPersistConfig, refImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nodes/store/workflowLibrarySlice';
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import { stylePresetPersistConfig, stylePresetSlice } from 'features/stylePresets/store/stylePresetSlice';
import { configSlice } from 'features/system/store/configSlice';
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
import { diff } from 'jsondiffpatch';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { Driver, SerializeFunction, UnserializeFunction } from 'redux-remember';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import undoable, { newHistory } from 'redux-undo';
import { serializeError } from 'serialize-error';
@@ -48,116 +37,123 @@ import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest';
import { STORAGE_PREFIX } from './constants';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { addArchivedOrDeletedBoardListener } from './middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener';
import { addImageUploadedFulfilledListener } from './middleware/listenerMiddleware/listeners/imageUploaded';
export const listenerMiddleware = createListenerMiddleware();
import { listenerMiddleware } from './middleware/listenerMiddleware';
const log = logger('system');
// When adding a slice, add the config to the SLICE_CONFIGS object below, then add the reducer to ALL_REDUCERS.
const SLICE_CONFIGS = {
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig,
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
[canvasSliceConfig.slice.reducerPath]: canvasSliceConfig,
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig,
[configSliceConfig.slice.reducerPath]: configSliceConfig,
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig,
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig,
[lorasSliceConfig.slice.reducerPath]: lorasSliceConfig,
[modelManagerSliceConfig.slice.reducerPath]: modelManagerSliceConfig,
[nodesSliceConfig.slice.reducerPath]: nodesSliceConfig,
[paramsSliceConfig.slice.reducerPath]: paramsSliceConfig,
[queueSliceConfig.slice.reducerPath]: queueSliceConfig,
[refImagesSliceConfig.slice.reducerPath]: refImagesSliceConfig,
[stylePresetSliceConfig.slice.reducerPath]: stylePresetSliceConfig,
[systemSliceConfig.slice.reducerPath]: systemSliceConfig,
[uiSliceConfig.slice.reducerPath]: uiSliceConfig,
[upscaleSliceConfig.slice.reducerPath]: upscaleSliceConfig,
[workflowLibrarySliceConfig.slice.reducerPath]: workflowLibrarySliceConfig,
[workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig,
};
// TS makes it really hard to dynamically create this object :/ so it's just hardcoded here.
// Remember to wrap undoable reducers in `undoable()`!
const ALL_REDUCERS = {
const allReducers = {
[api.reducerPath]: api.reducer,
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig.slice.reducer,
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig.slice.reducer,
// Undoable!
[canvasSliceConfig.slice.reducerPath]: undoable(
canvasSliceConfig.slice.reducer,
canvasSliceConfig.undoableConfig?.reduxUndoOptions
),
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig.slice.reducer,
[configSliceConfig.slice.reducerPath]: configSliceConfig.slice.reducer,
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig.slice.reducer,
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig.slice.reducer,
[lorasSliceConfig.slice.reducerPath]: lorasSliceConfig.slice.reducer,
[modelManagerSliceConfig.slice.reducerPath]: modelManagerSliceConfig.slice.reducer,
// Undoable!
[nodesSliceConfig.slice.reducerPath]: undoable(
nodesSliceConfig.slice.reducer,
nodesSliceConfig.undoableConfig?.reduxUndoOptions
),
[paramsSliceConfig.slice.reducerPath]: paramsSliceConfig.slice.reducer,
[queueSliceConfig.slice.reducerPath]: queueSliceConfig.slice.reducer,
[refImagesSliceConfig.slice.reducerPath]: refImagesSliceConfig.slice.reducer,
[stylePresetSliceConfig.slice.reducerPath]: stylePresetSliceConfig.slice.reducer,
[systemSliceConfig.slice.reducerPath]: systemSliceConfig.slice.reducer,
[uiSliceConfig.slice.reducerPath]: uiSliceConfig.slice.reducer,
[upscaleSliceConfig.slice.reducerPath]: upscaleSliceConfig.slice.reducer,
[workflowLibrarySliceConfig.slice.reducerPath]: workflowLibrarySliceConfig.slice.reducer,
[workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig.slice.reducer,
[gallerySlice.name]: gallerySlice.reducer,
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
[systemSlice.name]: systemSlice.reducer,
[configSlice.name]: configSlice.reducer,
[uiSlice.name]: uiSlice.reducer,
[dynamicPromptsSlice.name]: dynamicPromptsSlice.reducer,
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
[queueSlice.name]: queueSlice.reducer,
[canvasSlice.name]: undoable(canvasSlice.reducer, canvasUndoableConfig),
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[upscaleSlice.name]: upscaleSlice.reducer,
[stylePresetSlice.name]: stylePresetSlice.reducer,
[paramsSlice.name]: paramsSlice.reducer,
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
[lorasSlice.name]: lorasSlice.reducer,
[workflowLibrarySlice.name]: workflowLibrarySlice.reducer,
[refImagesSlice.name]: refImagesSlice.reducer,
};
const rootReducer = combineReducers(ALL_REDUCERS);
const rootReducer = combineReducers(allReducers);
const rememberedRootReducer = rememberReducer(rootReducer);
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export type PersistConfig<T = any> = {
/**
* The name of the slice.
*/
name: keyof typeof allReducers;
/**
* The initial state of the slice.
*/
initialState: T;
/**
* Migrate the state to the current version during rehydration.
* @param state The rehydrated state.
* @returns A correctly-shaped state.
*/
migrate: (state: unknown) => T;
/**
* Keys to omit from the persisted state.
*/
persistDenylist: (keyof T)[];
};
const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[galleryPersistConfig.name]: galleryPersistConfig,
[nodesPersistConfig.name]: nodesPersistConfig,
[systemPersistConfig.name]: systemPersistConfig,
[uiPersistConfig.name]: uiPersistConfig,
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[canvasPersistConfig.name]: canvasPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
[upscalePersistConfig.name]: upscalePersistConfig,
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
[paramsPersistConfig.name]: paramsPersistConfig,
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
[lorasPersistConfig.name]: lorasPersistConfig,
[workflowLibraryPersistConfig.name]: workflowLibraryPersistConfig,
[refImagesSlice.name]: refImagesPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {
const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
if (!sliceConfig?.persistConfig) {
const persistConfig = persistConfigs[key as keyof typeof persistConfigs];
if (!persistConfig) {
throw new Error(`No persist config for slice "${key}"`);
}
const { getInitialState, persistConfig, undoableConfig } = sliceConfig;
let state;
try {
const initialState = getInitialState();
const { initialState, migrate } = persistConfig;
const parsed = JSON.parse(data);
// strip out old keys
const stripped = pick(deepClone(data), keys(initialState));
const stripped = pick(deepClone(parsed), keys(initialState));
// run (additive) migrations
const migrated = migrate(stripped);
/*
* Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep,
* but that merges arrays by index and partial objects by key. Using an identity function as the customizer results
* in behaviour like defaultsDeep, but doesn't overwrite any values that are not undefined in the migrated state.
*/
const unPersistDenylisted = mergeWith(stripped, initialState, (objVal) => objVal);
// run (additive) migrations
const migrated = persistConfig.migrate(unPersistDenylisted);
const transformed = mergeWith(migrated, initialState, (objVal) => objVal);
log.debug(
{
persistedData: data as JsonObject,
rehydratedData: migrated as JsonObject,
diff: diff(data, migrated) as JsonObject,
persistedData: parsed,
rehydratedData: transformed,
diff: diff(parsed, transformed) as JsonObject, // this is always serializable
},
`Rehydrated slice "${key}"`
);
state = migrated;
state = transformed;
} catch (err) {
log.warn(
{ error: serializeError(err as Error) },
`Error rehydrating slice "${key}", falling back to default initial state`
);
state = getInitialState();
state = persistConfig.initialState;
}
// Undoable slices must be wrapped in a history!
if (undoableConfig) {
// If the slice is undoable, we need to wrap it in a new history - only nodes and canvas are undoable at the moment.
// TODO(psyche): make this automatic & remove the hard-coding for specific slices.
if (key === nodesSlice.name || key === canvasSlice.name) {
return newHistory([], state, []);
} else {
return state;
@@ -165,30 +161,21 @@ const unserialize: UnserializeFunction = (data, key) => {
};
const serialize: SerializeFunction = (data, key) => {
const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
if (!sliceConfig?.persistConfig) {
const persistConfig = persistConfigs[key as keyof typeof persistConfigs];
if (!persistConfig) {
throw new Error(`No persist config for slice "${key}"`);
}
const result = omit(
sliceConfig.undoableConfig ? data.present : data,
sliceConfig.persistConfig.persistDenylist ?? []
);
// Heuristic to determine if the slice is undoable - could just hardcode it in the persistConfig
const isUndoable = 'present' in data && 'past' in data && 'future' in data && '_latestUnfiltered' in data;
const result = omit(isUndoable ? data.present : data, persistConfig.persistDenylist);
return JSON.stringify(result);
};
const PERSISTED_KEYS = Object.values(SLICE_CONFIGS)
.filter((sliceConfig) => !!sliceConfig.persistConfig)
.map((sliceConfig) => sliceConfig.slice.reducerPath);
export const createStore = (reduxRememberOptions: { driver: Driver; persistThrottle: number }) =>
export const createStore = (uniqueStoreKey?: string, persist = true) =>
configureStore({
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
// serializableCheck: false,
// immutableCheck: false,
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
})
@@ -198,16 +185,19 @@ export const createStore = (reduxRememberOptions: { driver: Driver; persistThrot
// .concat(getDebugLoggerMiddleware())
.prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => {
const enhancers = getDefaultEnhancers();
return enhancers.prepend(
rememberEnhancer(reduxRememberOptions.driver, PERSISTED_KEYS, {
persistThrottle: reduxRememberOptions.persistThrottle,
serialize,
unserialize,
prefix: '',
errorHandler,
})
);
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
if (persist) {
_enhancers.push(
rememberEnhancer(idbKeyValDriver, keys(persistConfigs), {
persistDebounce: 300,
serialize,
unserialize,
prefix: uniqueStoreKey ? `${STORAGE_PREFIX}${uniqueStoreKey}-` : STORAGE_PREFIX,
errorHandler,
})
);
}
return _enhancers;
},
devTools: {
actionSanitizer,
@@ -224,48 +214,7 @@ export const createStore = (reduxRememberOptions: { driver: Driver; persistThrot
export type AppStore = ReturnType<typeof createStore>;
export type RootState = ReturnType<AppStore['getState']>;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
export type AppGetState = ReturnType<typeof createStore>['getState'];
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
const startAppListening = listenerMiddleware.startListening as AppStartListening;
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);
// app startup
addAppStartedListener(startAppListening);
addModelsLoadedListener(startAppListening);
addAppConfigReceivedListener(startAppListening);
// Ad-hoc upscale workflwo
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@@ -1,46 +0,0 @@
import type { Slice } from '@reduxjs/toolkit';
import type { UndoableOptions } from 'redux-undo';
import type { ZodType } from 'zod';
type StateFromSlice<T extends Slice> = T extends Slice<infer U> ? U : never;
export type SliceConfig<T extends Slice> = {
/**
* The redux slice (return of createSlice).
*/
slice: T;
/**
* The zod schema for the slice.
*/
schema: ZodType<StateFromSlice<T>>;
/**
* A function that returns the initial state of the slice.
*/
getInitialState: () => StateFromSlice<T>;
/**
* The optional persist configuration for this slice. If omitted, the slice will not be persisted.
*/
persistConfig?: {
/**
* Migrate the state to the current version during rehydration. This method should throw an error if the migration
* fails.
*
* @param state The rehydrated state.
* @returns A correctly-shaped state.
*/
migrate: (state: unknown) => StateFromSlice<T>;
/**
* Keys to omit from the persisted state.
*/
persistDenylist?: (keyof StateFromSlice<T>)[];
};
/**
* The optional undoable configuration for this slice. If omitted, the slice will not be undoable.
*/
undoableConfig?: {
/**
* The options to be passed into redux-undo.
*/
reduxUndoOptions: UndoableOptions<StateFromSlice<T>>;
};
};

View File

@@ -1,299 +1,130 @@
import { zFilterType } from 'features/controlLayers/store/filters';
import { zParameterPrecision, zParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { zTabName } from 'features/ui/store/uiTypes';
import type { FilterType } from 'features/controlLayers/store/filters';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { TabName } from 'features/ui/store/uiTypes';
import type { PartialDeep } from 'type-fest';
import z from 'zod';
const zAppFeature = z.enum([
'faceRestore',
'upscaling',
'lightbox',
'modelManager',
'githubLink',
'discordLink',
'bugLink',
'aboutModal',
'localization',
'consoleLogging',
'dynamicPrompting',
'batches',
'syncModels',
'multiselect',
'pauseQueue',
'resumeQueue',
'invocationCache',
'modelCache',
'bulkDownload',
'starterModels',
'hfToken',
'retryQueueItem',
'cancelAndClearAll',
'chatGPT4oHigh',
'modelRelationships',
]);
export type AppFeature = z.infer<typeof zAppFeature>;
/**
* A disable-able application feature
*/
export type AppFeature =
| 'faceRestore'
| 'upscaling'
| 'lightbox'
| 'modelManager'
| 'githubLink'
| 'discordLink'
| 'bugLink'
| 'aboutModal'
| 'localization'
| 'consoleLogging'
| 'dynamicPrompting'
| 'batches'
| 'syncModels'
| 'multiselect'
| 'pauseQueue'
| 'resumeQueue'
| 'invocationCache'
| 'modelCache'
| 'bulkDownload'
| 'starterModels'
| 'hfToken'
| 'retryQueueItem'
| 'cancelAndClearAll'
| 'chatGPT4oHigh'
| 'modelRelationships';
/**
* A disable-able Stable Diffusion feature
*/
export type SDFeature =
| 'controlNet'
| 'noise'
| 'perlinNoise'
| 'noiseThreshold'
| 'variation'
| 'symmetry'
| 'seamless'
| 'hires'
| 'lora'
| 'embedding'
| 'vae'
| 'hrf';
const zSDFeature = z.enum([
'controlNet',
'noise',
'perlinNoise',
'noiseThreshold',
'variation',
'symmetry',
'seamless',
'hires',
'lora',
'embedding',
'vae',
'hrf',
]);
export type SDFeature = z.infer<typeof zSDFeature>;
const zNumericalParameterConfig = z.object({
initial: z.number().default(512),
sliderMin: z.number().default(64),
sliderMax: z.number().default(1536),
numberInputMin: z.number().default(64),
numberInputMax: z.number().default(4096),
fineStep: z.number().default(8),
coarseStep: z.number().default(64),
});
export type NumericalParameterConfig = {
initial: number;
sliderMin: number;
sliderMax: number;
numberInputMin: number;
numberInputMax: number;
fineStep: number;
coarseStep: number;
};
/**
* Configuration options for the InvokeAI UI.
* Distinct from system settings which may be changed inside the app.
*/
export const zAppConfig = z.object({
export type AppConfig = {
/**
* Whether or not we should update image urls when image loading errors
*/
shouldUpdateImagesOnConnect: z.boolean(),
shouldFetchMetadataFromApi: z.boolean(),
shouldUpdateImagesOnConnect: boolean;
shouldFetchMetadataFromApi: boolean;
/**
* Sets a size limit for outputs on the upscaling tab. This is a maximum dimension, so the actual max number of pixels
* will be the square of this value.
*/
maxUpscaleDimension: z.number().optional(),
allowPrivateBoards: z.boolean(),
allowPrivateStylePresets: z.boolean(),
allowClientSideUpload: z.boolean(),
allowPublishWorkflows: z.boolean(),
allowPromptExpansion: z.boolean(),
disabledTabs: z.array(zTabName),
disabledFeatures: z.array(zAppFeature),
disabledSDFeatures: z.array(zSDFeature),
nodesAllowlist: z.array(z.string()).optional(),
nodesDenylist: z.array(z.string()).optional(),
metadataFetchDebounce: z.number().int().optional(),
workflowFetchDebounce: z.number().int().optional(),
isLocal: z.boolean().optional(),
shouldShowCredits: z.boolean().optional(),
sd: z.object({
defaultModel: z.string().optional(),
disabledControlNetModels: z.array(z.string()),
disabledControlNetProcessors: z.array(zFilterType),
// Core parameters
iterations: zNumericalParameterConfig,
width: zNumericalParameterConfig,
height: zNumericalParameterConfig,
steps: zNumericalParameterConfig,
guidance: zNumericalParameterConfig,
cfgRescaleMultiplier: zNumericalParameterConfig,
img2imgStrength: zNumericalParameterConfig,
scheduler: zParameterScheduler.optional(),
vaePrecision: zParameterPrecision.optional(),
// Canvas
boundingBoxHeight: zNumericalParameterConfig,
boundingBoxWidth: zNumericalParameterConfig,
scaledBoundingBoxHeight: zNumericalParameterConfig,
scaledBoundingBoxWidth: zNumericalParameterConfig,
canvasCoherenceStrength: zNumericalParameterConfig,
canvasCoherenceEdgeSize: zNumericalParameterConfig,
infillTileSize: zNumericalParameterConfig,
infillPatchmatchDownscaleSize: zNumericalParameterConfig,
// Misc advanced
clipSkip: zNumericalParameterConfig, // slider and input max are ignored for this, because the values depend on the model
maskBlur: zNumericalParameterConfig,
hrfStrength: zNumericalParameterConfig,
dynamicPrompts: z.object({
maxPrompts: zNumericalParameterConfig,
}),
ca: z.object({
weight: zNumericalParameterConfig,
}),
}),
flux: z.object({
guidance: zNumericalParameterConfig,
}),
});
export type AppConfig = z.infer<typeof zAppConfig>;
export type PartialAppConfig = PartialDeep<AppConfig>;
export const getDefaultAppConfig = (): AppConfig => ({
isLocal: true,
shouldUpdateImagesOnConnect: false,
shouldFetchMetadataFromApi: false,
allowPrivateBoards: false,
allowPrivateStylePresets: false,
allowClientSideUpload: false,
allowPublishWorkflows: false,
allowPromptExpansion: false,
shouldShowCredits: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'] satisfies AppFeature[],
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'] satisfies SDFeature[],
maxUpscaleDimension?: number;
allowPrivateBoards: boolean;
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
allowPublishWorkflows: boolean;
allowPromptExpansion: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
nodesAllowlist: string[] | undefined;
nodesDenylist: string[] | undefined;
metadataFetchDebounce?: number;
workflowFetchDebounce?: number;
isLocal?: boolean;
shouldShowCredits: boolean;
sd: {
disabledControlNetModels: [],
disabledControlNetProcessors: [],
iterations: {
initial: 1,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 1,
},
width: zNumericalParameterConfig.parse({}), // initial value comes from model
height: zNumericalParameterConfig.parse({}), // initial value comes from model
boundingBoxWidth: zNumericalParameterConfig.parse({}), // initial value comes from model
boundingBoxHeight: zNumericalParameterConfig.parse({}), // initial value comes from model
scaledBoundingBoxWidth: zNumericalParameterConfig.parse({}), // initial value comes from model
scaledBoundingBoxHeight: zNumericalParameterConfig.parse({}), // initial value comes from model
scheduler: 'dpmpp_3m_k' as const,
vaePrecision: 'fp32' as const,
steps: {
initial: 30,
sliderMin: 1,
sliderMax: 100,
numberInputMin: 1,
numberInputMax: 500,
fineStep: 1,
coarseStep: 1,
},
guidance: {
initial: 7,
sliderMin: 1,
sliderMax: 20,
numberInputMin: 1,
numberInputMax: 200,
fineStep: 0.1,
coarseStep: 0.5,
},
img2imgStrength: {
initial: 0.7,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceStrength: {
initial: 0.3,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
hrfStrength: {
initial: 0.45,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceEdgeSize: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 1024,
fineStep: 8,
coarseStep: 16,
},
cfgRescaleMultiplier: {
initial: 0,
sliderMin: 0,
sliderMax: 0.99,
numberInputMin: 0,
numberInputMax: 0.99,
fineStep: 0.05,
coarseStep: 0.1,
},
clipSkip: {
initial: 0,
sliderMin: 0,
sliderMax: 12, // determined by model selection, unused in practice
numberInputMin: 0,
numberInputMax: 12, // determined by model selection, unused in practice
fineStep: 1,
coarseStep: 1,
},
infillPatchmatchDownscaleSize: {
initial: 1,
sliderMin: 1,
sliderMax: 10,
numberInputMin: 1,
numberInputMax: 10,
fineStep: 1,
coarseStep: 1,
},
infillTileSize: {
initial: 32,
sliderMin: 16,
sliderMax: 64,
numberInputMin: 16,
numberInputMax: 256,
fineStep: 1,
coarseStep: 1,
},
maskBlur: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 512,
fineStep: 1,
coarseStep: 1,
},
ca: {
weight: {
initial: 1,
sliderMin: 0,
sliderMax: 2,
numberInputMin: -1,
numberInputMax: 2,
fineStep: 0.01,
coarseStep: 0.05,
},
},
defaultModel?: string;
disabledControlNetModels: string[];
disabledControlNetProcessors: FilterType[];
// Core parameters
iterations: NumericalParameterConfig;
width: NumericalParameterConfig; // initial value comes from model
height: NumericalParameterConfig; // initial value comes from model
steps: NumericalParameterConfig;
guidance: NumericalParameterConfig;
cfgRescaleMultiplier: NumericalParameterConfig;
img2imgStrength: NumericalParameterConfig;
scheduler?: ParameterScheduler;
vaePrecision?: ParameterPrecision;
// Canvas
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model
scaledBoundingBoxHeight: NumericalParameterConfig; // initial value comes from model
scaledBoundingBoxWidth: NumericalParameterConfig; // initial value comes from model
canvasCoherenceStrength: NumericalParameterConfig;
canvasCoherenceEdgeSize: NumericalParameterConfig;
infillTileSize: NumericalParameterConfig;
infillPatchmatchDownscaleSize: NumericalParameterConfig;
// Misc advanced
clipSkip: NumericalParameterConfig; // slider and input max are ignored for this, because the values depend on the model
maskBlur: NumericalParameterConfig;
hrfStrength: NumericalParameterConfig;
dynamicPrompts: {
maxPrompts: {
initial: 100,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 10,
},
},
},
maxPrompts: NumericalParameterConfig;
};
ca: {
weight: NumericalParameterConfig;
};
};
flux: {
guidance: {
initial: 4,
sliderMin: 2,
sliderMax: 6,
numberInputMin: 1,
numberInputMax: 20,
fineStep: 0.1,
coarseStep: 0.5,
},
},
});
guidance: NumericalParameterConfig;
};
};
export type PartialAppConfig = PartialDeep<AppConfig>;

View File

@@ -1,8 +1,6 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { canvasReset } from 'features/controlLayers/store/actions';
import { inpaintMaskAdded } from 'features/controlLayers/store/canvasSlice';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { allEntitiesDeleted } from 'features/controlLayers/store/canvasSlice';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@@ -13,9 +11,7 @@ export const SessionMenuItems = memo(() => {
const dispatch = useAppDispatch();
const resetCanvasLayers = useCallback(() => {
dispatch(canvasReset());
dispatch(inpaintMaskAdded({ isSelected: true, isBookmarked: true }));
$canvasManager.get()?.stage.fitBboxToStage();
dispatch(allEntitiesDeleted());
}, [dispatch]);
const resetGenerationSettings = useCallback(() => {
dispatch(paramsReset());

View File

@@ -0,0 +1,11 @@
import { clearIdbKeyValStore } from 'app/store/enhancers/reduxRemember/driver';
import { useCallback } from 'react';
export const useClearStorage = () => {
const clearStorage = useCallback(() => {
clearIdbKeyValStore();
localStorage.clear();
}, []);
return clearStorage;
};

View File

@@ -139,13 +139,4 @@ export const useGlobalHotkeys = () => {
},
dependencies: [getState, deleteImageModalApi],
});
useRegisteredHotkeys({
id: 'toggleViewer',
category: 'viewer',
callback: () => {
navigationApi.toggleViewerPanel();
},
dependencies: [],
});
};

View File

@@ -0,0 +1,6 @@
import type { ChangeBoardModalState } from './types';
export const initialState: ChangeBoardModalState = {
isModalOpen: false,
image_names: [],
};

View File

@@ -1,20 +1,12 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import z from 'zod';
const zChangeBoardModalState = z.object({
isModalOpen: z.boolean().default(false),
image_names: z.array(z.string()).default(() => []),
});
type ChangeBoardModalState = z.infer<typeof zChangeBoardModalState>;
import { initialState } from './initialState';
const getInitialState = (): ChangeBoardModalState => zChangeBoardModalState.parse({});
const slice = createSlice({
export const changeBoardModalSlice = createSlice({
name: 'changeBoardModal',
initialState: getInitialState(),
initialState,
reducers: {
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isModalOpen = action.payload;
@@ -29,12 +21,6 @@ const slice = createSlice({
},
});
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } = slice.actions;
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } = changeBoardModalSlice.actions;
export const selectChangeBoardModalSlice = (state: RootState) => state.changeBoardModal;
export const changeBoardModalSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zChangeBoardModalState,
getInitialState,
};

View File

@@ -0,0 +1,4 @@
export type ChangeBoardModalState = {
isModalOpen: boolean;
image_names: string[];
};

View File

@@ -165,9 +165,9 @@ export const CanvasEntityGroupList = memo(({ isSelected, type, children, entityI
<Spacer />
</Flex>
{type === 'raster_layer' && <RasterLayerExportPSDButton />}
<CanvasEntityMergeVisibleButton type={type} />
<CanvasEntityTypeIsHiddenToggle type={type} />
{type === 'raster_layer' && <RasterLayerExportPSDButton />}
<CanvasEntityAddOfTypeButton type={type} />
</Flex>
<Collapse in={collapse.isTrue} style={fixTooltipCloseOnScrollStyles}>

View File

@@ -42,7 +42,7 @@ const DEFAULT_CONFIG: CanvasStageModuleConfig = {
SCALE_FACTOR: 0.999,
FIT_LAYERS_TO_STAGE_PADDING_PX: 48,
SCALE_SNAP_POINTS: [0.25, 0.5, 0.75, 1, 1.5, 2, 3, 4, 5],
SCALE_SNAP_TOLERANCE: 0.02,
SCALE_SNAP_TOLERANCE: 0.05,
};
export class CanvasStageModule extends CanvasModuleBase {
@@ -366,22 +366,11 @@ export class CanvasStageModule extends CanvasModuleBase {
if (deltaT > 300) {
dynamicScaleFactor = this.config.SCALE_FACTOR + (1 - this.config.SCALE_FACTOR) / 2;
} else if (deltaT < 300) {
// Ensure dynamic scale factor stays below 1 to maintain zoom-out direction - if it goes over, we could end up
// zooming in the wrong direction with small scroll amounts
const maxScaleFactor = 0.9999;
dynamicScaleFactor = Math.min(
this.config.SCALE_FACTOR + (1 - this.config.SCALE_FACTOR) * (deltaT / 200),
maxScaleFactor
);
dynamicScaleFactor = this.config.SCALE_FACTOR + (1 - this.config.SCALE_FACTOR) * (deltaT / 200);
}
// Update the intended scale based on the last intended scale, creating a continuous zoom feel
// Handle the sign explicitly to prevent direction reversal with small scroll amounts
const scaleFactor =
scrollAmount > 0
? dynamicScaleFactor ** Math.abs(scrollAmount)
: (1 / dynamicScaleFactor) ** Math.abs(scrollAmount);
const newIntendedScale = this._intendedScale * scaleFactor;
const newIntendedScale = this._intendedScale * dynamicScaleFactor ** scrollAmount;
this._intendedScale = this.constrainScale(newIntendedScale);
// Pass control to the snapping logic
@@ -408,9 +397,6 @@ export class CanvasStageModule extends CanvasModuleBase {
// User has scrolled far enough to break the snap
this._activeSnapPoint = null;
this._applyScale(this._intendedScale, center);
} else {
// Reset intended scale to prevent drift while snapped
this._intendedScale = this._activeSnapPoint;
}
// Else, do nothing - we remain snapped at the current scale, creating a "dead zone"
return;

View File

@@ -1,7 +1,7 @@
import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Selector } from '@reduxjs/toolkit';
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
import type { AppStore, RootState } from 'app/store/store';
import { addAppListener } from 'app/store/store';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';

View File

@@ -1,7 +1,6 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import type { PersistConfig, RootState } from 'app/store/store';
import { zRgbaColor } from 'features/controlLayers/store/types';
import { z } from 'zod';
@@ -12,32 +11,32 @@ const zCanvasSettingsState = z.object({
/**
* Whether to show HUD (Heads-Up Display) on the canvas.
*/
showHUD: z.boolean(),
showHUD: z.boolean().default(true),
/**
* Whether to clip lines and shapes to the generation bounding box. If disabled, lines and shapes will be clipped to
* the canvas bounds.
*/
clipToBbox: z.boolean(),
clipToBbox: z.boolean().default(false),
/**
* Whether to show a dynamic grid on the canvas. If disabled, a checkerboard pattern will be shown instead.
*/
dynamicGrid: z.boolean(),
dynamicGrid: z.boolean().default(false),
/**
* Whether to invert the scroll direction when adjusting the brush or eraser width with the scroll wheel.
*/
invertScrollForToolWidth: z.boolean(),
invertScrollForToolWidth: z.boolean().default(false),
/**
* The width of the brush tool.
*/
brushWidth: z.int().gt(0),
brushWidth: z.int().gt(0).default(50),
/**
* The width of the eraser tool.
*/
eraserWidth: z.int().gt(0),
eraserWidth: z.int().gt(0).default(50),
/**
* The color to use when drawing lines or filling shapes.
*/
color: zRgbaColor,
color: zRgbaColor.default({ r: 31, g: 160, b: 224, a: 1 }), // invokeBlue.500
/**
* Whether to composite inpainted/outpainted regions back onto the source image when saving canvas generations.
*
@@ -45,77 +44,57 @@ const zCanvasSettingsState = z.object({
*
* When `sendToCanvas` is disabled, this setting is ignored, masked regions will always be composited.
*/
outputOnlyMaskedRegions: z.boolean(),
outputOnlyMaskedRegions: z.boolean().default(true),
/**
* Whether to automatically process the operations like filtering and auto-masking.
*/
autoProcess: z.boolean(),
autoProcess: z.boolean().default(true),
/**
* The snap-to-grid setting for the canvas.
*/
snapToGrid: z.boolean(),
snapToGrid: z.boolean().default(true),
/**
* Whether to show progress on the canvas when generating images.
*/
showProgressOnCanvas: z.boolean(),
showProgressOnCanvas: z.boolean().default(true),
/**
* Whether to show the bounding box overlay on the canvas.
*/
bboxOverlay: z.boolean(),
bboxOverlay: z.boolean().default(false),
/**
* Whether to preserve the masked region instead of inpainting it.
*/
preserveMask: z.boolean(),
preserveMask: z.boolean().default(false),
/**
* Whether to show only raster layers while staging.
*/
isolatedStagingPreview: z.boolean(),
isolatedStagingPreview: z.boolean().default(true),
/**
* Whether to show only the selected layer while filtering, transforming, or doing other operations.
*/
isolatedLayerPreview: z.boolean(),
isolatedLayerPreview: z.boolean().default(true),
/**
* Whether to use pressure sensitivity for the brush and eraser tool when a pen device is used.
*/
pressureSensitivity: z.boolean(),
pressureSensitivity: z.boolean().default(true),
/**
* Whether to show the rule of thirds composition guide overlay on the canvas.
*/
ruleOfThirds: z.boolean(),
ruleOfThirds: z.boolean().default(false),
/**
* Whether to save all staging images to the gallery instead of keeping them as intermediate images.
*/
saveAllImagesToGallery: z.boolean(),
saveAllImagesToGallery: z.boolean().default(false),
/**
* The auto-switch mode for the canvas staging area.
*/
stagingAreaAutoSwitch: zAutoSwitchMode,
stagingAreaAutoSwitch: zAutoSwitchMode.default('switch_on_start'),
});
type CanvasSettingsState = z.infer<typeof zCanvasSettingsState>;
const getInitialState = (): CanvasSettingsState => ({
showHUD: true,
clipToBbox: false,
dynamicGrid: false,
invertScrollForToolWidth: false,
brushWidth: 50,
eraserWidth: 50,
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
outputOnlyMaskedRegions: true,
autoProcess: true,
snapToGrid: true,
showProgressOnCanvas: true,
bboxOverlay: false,
preserveMask: false,
isolatedStagingPreview: true,
isolatedLayerPreview: true,
pressureSensitivity: true,
ruleOfThirds: false,
saveAllImagesToGallery: false,
stagingAreaAutoSwitch: 'switch_on_start',
});
const getInitialState = () => zCanvasSettingsState.parse({});
const slice = createSlice({
export const canvasSettingsSlice = createSlice({
name: 'canvasSettings',
initialState: getInitialState(),
reducers: {
@@ -205,15 +184,18 @@ export const {
settingsRuleOfThirdsToggled,
settingsSaveAllImagesToGalleryToggled,
settingsStagingAreaAutoSwitchChanged,
} = slice.actions;
} = canvasSettingsSlice.actions;
export const canvasSettingsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zCanvasSettingsState,
getInitialState,
persistConfig: {
migrate: (state) => zCanvasSettingsState.parse(state),
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const canvasSettingsPersistConfig: PersistConfig<CanvasSettingsState> = {
name: canvasSettingsSlice.name,
initialState: getInitialState(),
migrate,
persistDenylist: [],
};
export const selectCanvasSettingsSlice = (s: RootState) => s.canvasSettings;

View File

@@ -1,6 +1,6 @@
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { SliceConfig } from 'app/store/types';
import type { PersistConfig } from 'app/store/store';
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
@@ -80,7 +80,6 @@ import {
isFLUXReduxConfig,
isImagenAspectRatioID,
isIPAdapterConfig,
zCanvasState,
} from './types';
import {
converters,
@@ -96,7 +95,7 @@ import {
initialT2IAdapter,
} from './util';
const slice = createSlice({
export const canvasSlice = createSlice({
name: 'canvas',
initialState: getInitialCanvasState(),
reducers: {
@@ -1619,6 +1618,7 @@ export const {
entityArrangedToBack,
entityOpacityChanged,
entitiesReordered,
allEntitiesDeleted,
allEntitiesOfTypeIsHiddenToggled,
allNonRasterLayersIsHiddenToggled,
// bbox
@@ -1676,7 +1676,19 @@ export const {
inpaintMaskDenoiseLimitChanged,
inpaintMaskDenoiseLimitDeleted,
// inpaintMaskRecalled,
} = slice.actions;
} = canvasSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const canvasPersistConfig: PersistConfig<CanvasState> = {
name: canvasSlice.name,
initialState: getInitialCanvasState(),
migrate,
persistDenylist: [],
};
const syncScaledSize = (state: CanvasState) => {
if (API_BASE_MODELS.includes(state.bbox.modelBase)) {
@@ -1699,14 +1711,14 @@ const syncScaledSize = (state: CanvasState) => {
let filter = true;
const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
export const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
limit: 64,
undoType: canvasUndo.type,
redoType: canvasRedo.type,
clearHistoryType: canvasClearHistory.type,
filter: (action, _state, _history) => {
// Ignore all actions from other slices
if (!action.type.startsWith(slice.name)) {
if (!action.type.startsWith(canvasSlice.name)) {
return false;
}
// Throttle rapid actions of the same type
@@ -1717,18 +1729,6 @@ const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
// debug: import.meta.env.MODE === 'development',
};
export const canvasSliceConfig: SliceConfig<typeof slice> = {
slice,
getInitialState: getInitialCanvasState,
schema: zCanvasState,
persistConfig: {
migrate: (state) => zCanvasState.parse(state),
},
undoableConfig: {
reduxUndoOptions: canvasUndoableConfig,
},
};
const doNotGroupMatcher = isAnyOf(entityBrushLineAdded, entityEraserLineAdded, entityRectAdded);
// Store rapid actions of the same type at most once every x time.

View File

@@ -1,29 +1,27 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { RootState } from 'app/store/store';
import type { PersistConfig, RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { deepClone } from 'common/util/deepClone';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { useMemo } from 'react';
import { queueApi } from 'services/api/endpoints/queue';
import { assert } from 'tsafe';
import z from 'zod';
const zCanvasStagingAreaState = z.object({
_version: z.literal(1),
canvasSessionId: z.string(),
canvasDiscardedQueueItems: z.array(z.number().int()),
});
type CanvasStagingAreaState = z.infer<typeof zCanvasStagingAreaState>;
type CanvasStagingAreaState = {
_version: 1;
canvasSessionId: string;
canvasDiscardedQueueItems: number[];
};
const getInitialState = (): CanvasStagingAreaState => ({
const INITIAL_STATE: CanvasStagingAreaState = {
_version: 1,
canvasSessionId: getPrefixedId('canvas'),
canvasDiscardedQueueItems: [],
});
};
const slice = createSlice({
const getInitialState = (): CanvasStagingAreaState => deepClone(INITIAL_STATE);
export const canvasSessionSlice = createSlice({
name: 'canvasSession',
initialState: getInitialState(),
reducers: {
@@ -50,26 +48,26 @@ const slice = createSlice({
},
});
export const { canvasSessionReset, canvasQueueItemDiscarded } = slice.actions;
export const { canvasSessionReset, canvasQueueItemDiscarded } = canvasSessionSlice.actions;
export const canvasSessionSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zCanvasStagingAreaState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
return zCanvasStagingAreaState.parse(state);
},
},
return state;
};
export const selectCanvasSessionSlice = (s: RootState) => s[slice.name];
export const canvasStagingAreaPersistConfig: PersistConfig<CanvasStagingAreaState> = {
name: canvasSessionSlice.name,
initialState: getInitialState(),
migrate,
persistDenylist: [],
};
export const selectCanvasSessionSlice = (s: RootState) => s[canvasSessionSlice.name];
export const selectCanvasSessionId = createSelector(selectCanvasSessionSlice, ({ canvasSessionId }) => canvasSessionId);
const selectDiscardedItems = createSelector(

View File

@@ -166,7 +166,7 @@ const _zFilterConfig = z.discriminatedUnion('type', [
]);
export type FilterConfig = z.infer<typeof _zFilterConfig>;
export const zFilterType = z.enum([
const zFilterType = z.enum([
'adjust_image',
'canny_edge_detection',
'color_map',

View File

@@ -1,32 +1,30 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { type LoRA, zLoRA } from 'features/controlLayers/store/types';
import type { LoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { LoRAModelConfig } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import z from 'zod';
const zLoRAsState = z.object({
loras: z.array(zLoRA),
});
type LoRAsState = z.infer<typeof zLoRAsState>;
type LoRAsState = {
loras: LoRA[];
};
const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};
const getInitialState = (): LoRAsState => ({
const initialState: LoRAsState = {
loras: [],
});
};
const selectLoRA = (state: LoRAsState, id: string) => state.loras.find((lora) => lora.id === id);
const slice = createSlice({
export const lorasSlice = createSlice({
name: 'loras',
initialState: getInitialState(),
initialState,
reducers: {
loraAdded: {
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
@@ -68,21 +66,24 @@ const slice = createSlice({
extraReducers(builder) {
builder.addCase(paramsReset, () => {
// When a new session is requested, clear all LoRAs
return getInitialState();
return deepClone(initialState);
});
},
});
export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } =
slice.actions;
lorasSlice.actions;
export const lorasSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zLoRAsState,
getInitialState,
persistConfig: {
migrate: (state) => zLoRAsState.parse(state),
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const lorasPersistConfig: PersistConfig<LoRAsState> = {
name: lorasSlice.name,
initialState,
migrate,
persistDenylist: [],
};
export const selectLoRAsSlice = (state: RootState) => state.loras;

View File

@@ -1,7 +1,6 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
import { clamp } from 'es-toolkit/compat';
@@ -16,7 +15,6 @@ import {
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isImagenAspectRatioID,
zParamsState,
} from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
@@ -42,7 +40,7 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
const slice = createSlice({
export const paramsSlice = createSlice({
name: 'params',
initialState: getInitialParamsState(),
reducers: {
@@ -94,12 +92,7 @@ const slice = createSlice({
state,
action: PayloadAction<{ model: ParameterModel | null; previousModel?: ParameterModel | null }>
) => {
const { previousModel } = action.payload;
const result = zParamsState.shape.model.safeParse(action.payload.model);
if (!result.success) {
return;
}
const model = result.data;
const { model, previousModel } = action.payload;
state.model = model;
// If the model base changes (e.g. SD1.5 -> SDXL), we need to change a few things
@@ -118,53 +111,25 @@ const slice = createSlice({
},
vaeSelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
// null is a valid VAE!
const result = zParamsState.shape.vae.safeParse(action.payload);
if (!result.success) {
return;
}
state.vae = result.data;
state.vae = action.payload;
},
fluxVAESelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
const result = zParamsState.shape.fluxVAE.safeParse(action.payload);
if (!result.success) {
return;
}
state.fluxVAE = result.data;
state.fluxVAE = action.payload;
},
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
const result = zParamsState.shape.t5EncoderModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.t5EncoderModel = result.data;
state.t5EncoderModel = action.payload;
},
controlLoRAModelSelected: (state, action: PayloadAction<ParameterControlLoRAModel | null>) => {
const result = zParamsState.shape.controlLora.safeParse(action.payload);
if (!result.success) {
return;
}
state.controlLora = result.data;
state.controlLora = action.payload;
},
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
const result = zParamsState.shape.clipEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipEmbedModel = result.data;
state.clipEmbedModel = action.payload;
},
clipLEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPLEmbedModel | null>) => {
const result = zParamsState.shape.clipLEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipLEmbedModel = result.data;
state.clipLEmbedModel = action.payload;
},
clipGEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPGEmbedModel | null>) => {
const result = zParamsState.shape.clipGEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipGEmbedModel = result.data;
state.clipGEmbedModel = action.payload;
},
vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => {
state.vaePrecision = action.payload;
@@ -191,11 +156,7 @@ const slice = createSlice({
state.shouldConcatPrompts = action.payload;
},
refinerModelChanged: (state, action: PayloadAction<ParameterSDXLRefinerModel | null>) => {
const result = zParamsState.shape.refinerModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.refinerModel = result.data;
state.refinerModel = action.payload;
},
setRefinerSteps: (state, action: PayloadAction<number>) => {
state.refinerSteps = action.payload;
@@ -436,15 +397,18 @@ export const {
syncedToOptimalDimension,
paramsReset,
} = slice.actions;
} = paramsSlice.actions;
export const paramsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zParamsState,
getInitialState: getInitialParamsState,
persistConfig: {
migrate: (state) => zParamsState.parse(state),
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const paramsPersistConfig: PersistConfig<ParamsState> = {
name: paramsSlice.name,
initialState: getInitialParamsState(),
migrate,
persistDenylist: [],
};
export const selectParamsSlice = (state: RootState) => state.params;

View File

@@ -2,8 +2,7 @@ import { objectEquals } from '@observ33r/object-equals';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import type { PersistConfig, RootState } from 'app/store/store';
import { clamp } from 'es-toolkit/compat';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types';
@@ -19,7 +18,7 @@ import { assert } from 'tsafe';
import type { PartialDeep } from 'type-fest';
import type { CLIPVisionModelV2, IPMethodV2, RefImageState } from './types';
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig, zRefImagesState } from './types';
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig } from './types';
import {
getReferenceImageState,
imageDTOToImageWithDims,
@@ -37,7 +36,7 @@ type PayloadActionWithId<T = void> = T extends void
} & T
>;
const slice = createSlice({
export const refImagesSlice = createSlice({
name: 'refImages',
initialState: getInitialRefImagesState(),
reducers: {
@@ -264,16 +263,18 @@ export const {
refImageFLUXReduxImageInfluenceChanged,
refImageIsEnabledToggled,
refImagesRecalled,
} = slice.actions;
} = refImagesSlice.actions;
export const refImagesSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zRefImagesState,
getInitialState: getInitialRefImagesState,
persistConfig: {
migrate: (state) => zRefImagesState.parse(state),
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const refImagesPersistConfig: PersistConfig<RefImagesState> = {
name: refImagesSlice.name,
initialState: getInitialRefImagesState(),
migrate,
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
};
export const selectRefImagesSlice = (state: RootState) => state.refImages;

View File

@@ -1,7 +1,9 @@
import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
import type { ProgressImage } from 'features/nodes/types/common';
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import {
zParameterCanvasCoherenceMode,
zParameterCFGRescaleMultiplier,
@@ -27,17 +29,33 @@ import {
zParameterT5EncoderModel,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { getImageDTOSafe } from 'services/api/endpoints/images';
import type { JsonObject } from 'type-fest';
import { z } from 'zod';
const zId = z.string().min(1);
const zName = z.string().min(1).nullable();
export const zImageWithDims = z.object({
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => {
try {
await fetchModelConfigByIdentifier(modelIdentifier);
return true;
} catch {
return false;
}
});
const zImageWithDims = z
.object({
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
})
.refine(async (v) => {
const { image_name } = v;
const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true });
return imageDTO !== null;
});
export type ImageWithDims = z.infer<typeof zImageWithDims>;
const zImageWithDimsDataURL = z.object({
@@ -235,7 +253,7 @@ export type CanvasObjectState = z.infer<typeof zCanvasObjectState>;
const zIPAdapterConfig = z.object({
type: z.literal('ip_adapter'),
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
method: zIPMethodV2,
@@ -250,7 +268,7 @@ export type FLUXReduxImageInfluence = z.infer<typeof zFLUXReduxImageInfluence>;
const zFLUXReduxConfig = z.object({
type: z.literal('flux_redux'),
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
});
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
@@ -263,14 +281,14 @@ const zChatGPT4oReferenceImageConfig = z.object({
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else
* there will be no way to switch between ref image types.
*/
model: zModelIdentifierField.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
});
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
const zFluxKontextReferenceImageConfig = z.object({
type: z.literal('flux_kontext_reference_image'),
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
});
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;
@@ -342,7 +360,7 @@ export type CanvasInpaintMaskState = z.infer<typeof zCanvasInpaintMaskState>;
const zControlNetConfig = z.object({
type: z.literal('controlnet'),
model: zModelIdentifierField.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
controlMode: zControlModeV2,
@@ -351,7 +369,7 @@ export type ControlNetConfig = z.infer<typeof zControlNetConfig>;
const zT2IAdapterConfig = z.object({
type: z.literal('t2i_adapter'),
model: zModelIdentifierField.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
});
@@ -360,7 +378,7 @@ export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
const zControlLoRAConfig = z.object({
type: z.literal('control_lora'),
weight: z.number().gte(-1).lte(2),
model: zModelIdentifierField.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
});
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
@@ -406,13 +424,12 @@ export const zCanvasEntityIdentifer = z.object({
});
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
export const zLoRA = z.object({
id: z.string(),
isEnabled: z.boolean(),
model: zModelIdentifierField,
weight: z.number().gte(-1).lte(2),
});
export type LoRA = z.infer<typeof zLoRA>;
export type LoRA = {
id: string;
isEnabled: boolean;
model: ParameterLoRAModel;
weight: number;
};
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
@@ -505,108 +522,62 @@ const zDimensionsState = z.object({
aspectRatio: zAspectRatioConfig,
});
export const zParamsState = z.object({
maskBlur: z.number(),
maskBlurMethod: zParameterMaskBlurMethod,
canvasCoherenceMode: zParameterCanvasCoherenceMode,
canvasCoherenceMinDenoise: zParameterStrength,
canvasCoherenceEdgeSize: z.number(),
infillMethod: z.string(),
infillTileSize: z.number(),
infillPatchmatchDownscaleSize: z.number(),
infillColorValue: zRgbaColor,
cfgScale: zParameterCFGScale,
cfgRescaleMultiplier: zParameterCFGRescaleMultiplier,
guidance: zParameterGuidance,
img2imgStrength: zParameterStrength,
optimizedDenoisingEnabled: z.boolean(),
iterations: z.number(),
scheduler: zParameterScheduler,
upscaleScheduler: zParameterScheduler,
upscaleCfgScale: zParameterCFGScale,
seed: zParameterSeed,
shouldRandomizeSeed: z.boolean(),
steps: zParameterSteps,
model: zParameterModel.nullable(),
vae: zParameterVAEModel.nullable(),
vaePrecision: zParameterPrecision,
fluxVAE: zParameterVAEModel.nullable(),
seamlessXAxis: z.boolean(),
seamlessYAxis: z.boolean(),
clipSkip: z.number(),
shouldUseCpuNoise: z.boolean(),
positivePrompt: zParameterPositivePrompt,
negativePrompt: zParameterNegativePrompt,
positivePrompt2: zParameterPositiveStylePromptSDXL,
negativePrompt2: zParameterNegativeStylePromptSDXL,
shouldConcatPrompts: z.boolean(),
refinerModel: zParameterSDXLRefinerModel.nullable(),
refinerSteps: z.number(),
refinerCFGScale: z.number(),
refinerScheduler: zParameterScheduler,
refinerPositiveAestheticScore: z.number(),
refinerNegativeAestheticScore: z.number(),
refinerStart: z.number(),
t5EncoderModel: zParameterT5EncoderModel.nullable(),
clipEmbedModel: zParameterCLIPEmbedModel.nullable(),
clipLEmbedModel: zParameterCLIPLEmbedModel.nullable(),
clipGEmbedModel: zParameterCLIPGEmbedModel.nullable(),
controlLora: zParameterControlLoRAModel.nullable(),
dimensions: zDimensionsState,
const zParamsState = z.object({
maskBlur: z.number().default(16),
maskBlurMethod: zParameterMaskBlurMethod.default('box'),
canvasCoherenceMode: zParameterCanvasCoherenceMode.default('Gaussian Blur'),
canvasCoherenceMinDenoise: zParameterStrength.default(0),
canvasCoherenceEdgeSize: z.number().default(16),
infillMethod: z.string().default('lama'),
infillTileSize: z.number().default(32),
infillPatchmatchDownscaleSize: z.number().default(1),
infillColorValue: zRgbaColor.default({ r: 0, g: 0, b: 0, a: 1 }),
cfgScale: zParameterCFGScale.default(7.5),
cfgRescaleMultiplier: zParameterCFGRescaleMultiplier.default(0),
guidance: zParameterGuidance.default(4),
img2imgStrength: zParameterStrength.default(0.75),
optimizedDenoisingEnabled: z.boolean().default(true),
iterations: z.number().default(1),
scheduler: zParameterScheduler.default('dpmpp_3m_k'),
upscaleScheduler: zParameterScheduler.default('kdpm_2'),
upscaleCfgScale: zParameterCFGScale.default(2),
seed: zParameterSeed.default(0),
shouldRandomizeSeed: z.boolean().default(true),
steps: zParameterSteps.default(30),
model: zParameterModel.nullable().default(null),
vae: zParameterVAEModel.nullable().default(null),
vaePrecision: zParameterPrecision.default('fp32'),
fluxVAE: zParameterVAEModel.nullable().default(null),
seamlessXAxis: z.boolean().default(false),
seamlessYAxis: z.boolean().default(false),
clipSkip: z.number().default(0),
shouldUseCpuNoise: z.boolean().default(true),
positivePrompt: zParameterPositivePrompt.default(''),
// Negative prompt may be disabled, in which case it will be null
negativePrompt: zParameterNegativePrompt.default(null),
positivePrompt2: zParameterPositiveStylePromptSDXL.default(''),
negativePrompt2: zParameterNegativeStylePromptSDXL.default(''),
shouldConcatPrompts: z.boolean().default(true),
refinerModel: zParameterSDXLRefinerModel.nullable().default(null),
refinerSteps: z.number().default(20),
refinerCFGScale: z.number().default(7.5),
refinerScheduler: zParameterScheduler.default('euler'),
refinerPositiveAestheticScore: z.number().default(6),
refinerNegativeAestheticScore: z.number().default(2.5),
refinerStart: z.number().default(0.8),
t5EncoderModel: zParameterT5EncoderModel.nullable().default(null),
clipEmbedModel: zParameterCLIPEmbedModel.nullable().default(null),
clipLEmbedModel: zParameterCLIPLEmbedModel.nullable().default(null),
clipGEmbedModel: zParameterCLIPGEmbedModel.nullable().default(null),
controlLora: zParameterControlLoRAModel.nullable().default(null),
dimensions: zDimensionsState.default({
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
}),
});
export type ParamsState = z.infer<typeof zParamsState>;
export const getInitialParamsState = (): ParamsState => ({
maskBlur: 16,
maskBlurMethod: 'box',
canvasCoherenceMode: 'Gaussian Blur',
canvasCoherenceMinDenoise: 0,
canvasCoherenceEdgeSize: 16,
infillMethod: 'lama',
infillTileSize: 32,
infillPatchmatchDownscaleSize: 1,
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
cfgScale: 7.5,
cfgRescaleMultiplier: 0,
guidance: 4,
img2imgStrength: 0.75,
optimizedDenoisingEnabled: true,
iterations: 1,
scheduler: 'dpmpp_3m_k',
upscaleScheduler: 'kdpm_2',
upscaleCfgScale: 2,
seed: 0,
shouldRandomizeSeed: true,
steps: 30,
model: null,
vae: null,
vaePrecision: 'fp32',
fluxVAE: null,
seamlessXAxis: false,
seamlessYAxis: false,
clipSkip: 0,
shouldUseCpuNoise: true,
positivePrompt: '',
negativePrompt: null,
positivePrompt2: '',
negativePrompt2: '',
shouldConcatPrompts: true,
refinerModel: null,
refinerSteps: 20,
refinerCFGScale: 7.5,
refinerScheduler: 'euler',
refinerPositiveAestheticScore: 6,
refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8,
t5EncoderModel: null,
clipEmbedModel: null,
clipLEmbedModel: null,
clipGEmbedModel: null,
controlLora: null,
dimensions: {
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
},
});
const INITIAL_PARAMS_STATE = zParamsState.parse({});
export const getInitialParamsState = () => deepClone(INITIAL_PARAMS_STATE);
const zInpaintMasks = z.object({
isHidden: z.boolean(),
@@ -624,45 +595,38 @@ const zRegionalGuidance = z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRegionalGuidanceState),
});
export const zCanvasState = z.object({
_version: z.literal(3),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
inpaintMasks: zInpaintMasks,
rasterLayers: zRasterLayers,
controlLayers: zControlLayers,
regionalGuidance: zRegionalGuidance,
bbox: zBboxState,
});
export type CanvasState = z.infer<typeof zCanvasState>;
export const getInitialCanvasState = (): CanvasState => ({
_version: 3,
selectedEntityIdentifier: null,
bookmarkedEntityIdentifier: null,
inpaintMasks: { isHidden: false, entities: [] },
rasterLayers: { isHidden: false, entities: [] },
controlLayers: { isHidden: false, entities: [] },
regionalGuidance: { isHidden: false, entities: [] },
bbox: {
const zCanvasState = z.object({
_version: z.literal(3).default(3),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
inpaintMasks: zInpaintMasks.default({ isHidden: false, entities: [] }),
rasterLayers: zRasterLayers.default({ isHidden: false, entities: [] }),
controlLayers: zControlLayers.default({ isHidden: false, entities: [] }),
regionalGuidance: zRegionalGuidance.default({ isHidden: false, entities: [] }),
bbox: zBboxState.default({
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
scaleMethod: 'auto',
scaledSize: { width: 512, height: 512 },
modelBase: 'sd-1',
},
}),
});
export type CanvasState = z.infer<typeof zCanvasState>;
export const zRefImagesState = z.object({
selectedEntityId: z.string().nullable(),
isPanelOpen: z.boolean(),
entities: z.array(zRefImageState),
const zRefImagesState = z.object({
selectedEntityId: z.string().nullable().default(null),
isPanelOpen: z.boolean().default(false),
entities: z.array(zRefImageState).default(() => []),
});
export type RefImagesState = z.infer<typeof zRefImagesState>;
export const getInitialRefImagesState = (): RefImagesState => ({
selectedEntityId: null,
isPanelOpen: false,
entities: [],
});
const INITIAL_REF_IMAGES_STATE = zRefImagesState.parse({});
export const getInitialRefImagesState = () => deepClone(INITIAL_REF_IMAGES_STATE);
/**
* Gets a fresh canvas initial state with no references in memory to existing objects.
*/
const CANVAS_INITIAL_STATE = zCanvasState.parse({});
export const getInitialCanvasState = () => deepClone(CANVAS_INITIAL_STATE);
export const zCanvasReferenceImageState_OLD = zCanvasEntityBase.extend({
type: z.literal('reference_image'),

View File

@@ -1,29 +1,25 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import type { PersistConfig, RootState } from 'app/store/store';
import { buildZodTypeGuard } from 'common/util/zodUtils';
import { isPlainObject } from 'es-toolkit';
import { assert } from 'tsafe';
import { z } from 'zod';
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
export const isSeedBehaviour = buildZodTypeGuard(zSeedBehaviour);
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
const zDynamicPromptsState = z.object({
_version: z.literal(1),
maxPrompts: z.number().int().min(1).max(1000),
combinatorial: z.boolean(),
prompts: z.array(z.string()),
parsingError: z.string().nullish(),
isError: z.boolean(),
isLoading: z.boolean(),
seedBehaviour: zSeedBehaviour,
});
export type DynamicPromptsState = z.infer<typeof zDynamicPromptsState>;
export interface DynamicPromptsState {
_version: 1;
maxPrompts: number;
combinatorial: boolean;
prompts: string[];
parsingError: string | undefined | null;
isError: boolean;
isLoading: boolean;
seedBehaviour: SeedBehaviour;
}
const getInitialState = (): DynamicPromptsState => ({
const initialDynamicPromptsState: DynamicPromptsState = {
_version: 1,
maxPrompts: 100,
combinatorial: true,
@@ -32,11 +28,11 @@ const getInitialState = (): DynamicPromptsState => ({
isError: false,
isLoading: false,
seedBehaviour: 'PER_ITERATION',
});
};
const slice = createSlice({
export const dynamicPromptsSlice = createSlice({
name: 'dynamicPrompts',
initialState: getInitialState(),
initialState: initialDynamicPromptsState,
reducers: {
maxPromptsChanged: (state, action: PayloadAction<number>) => {
state.maxPrompts = action.payload;
@@ -67,22 +63,21 @@ export const {
isErrorChanged,
isLoadingChanged,
seedBehaviourChanged,
} = slice.actions;
} = dynamicPromptsSlice.actions;
export const dynamicPromptsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zDynamicPromptsState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zDynamicPromptsState.parse(state);
},
persistDenylist: ['prompts', 'parsingError', 'isError', 'isLoading'],
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateDynamicPromptsState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const dynamicPromptsPersistConfig: PersistConfig<DynamicPromptsState> = {
name: dynamicPromptsSlice.name,
initialState: initialDynamicPromptsState,
migrate: migrateDynamicPromptsState,
persistDenylist: ['prompts'],
};
export const selectDynamicPromptsSlice = (state: RootState) => state.dynamicPrompts;

View File

@@ -21,14 +21,7 @@ export const ImageMenuItemNewCanvasFromImageSubMenu = memo(() => {
const onClickNewCanvasWithRasterLayerFromImage = useCallback(async () => {
const { dispatch, getState } = store;
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
await newCanvasFromImage({
imageDTO,
withResize: false,
withInpaintMask: true,
type: 'raster_layer',
dispatch,
getState,
});
await newCanvasFromImage({ imageDTO, withResize: false, type: 'raster_layer', dispatch, getState });
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
@@ -39,14 +32,7 @@ export const ImageMenuItemNewCanvasFromImageSubMenu = memo(() => {
const onClickNewCanvasWithControlLayerFromImage = useCallback(async () => {
const { dispatch, getState } = store;
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
await newCanvasFromImage({
imageDTO,
withResize: false,
withInpaintMask: true,
type: 'control_layer',
dispatch,
getState,
});
await newCanvasFromImage({ imageDTO, withResize: false, type: 'control_layer', dispatch, getState });
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
@@ -57,14 +43,7 @@ export const ImageMenuItemNewCanvasFromImageSubMenu = memo(() => {
const onClickNewCanvasWithRasterLayerFromImageWithResize = useCallback(async () => {
const { dispatch, getState } = store;
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
await newCanvasFromImage({
imageDTO,
withResize: true,
withInpaintMask: true,
type: 'raster_layer',
dispatch,
getState,
});
await newCanvasFromImage({ imageDTO, withResize: true, type: 'raster_layer', dispatch, getState });
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
@@ -75,14 +54,7 @@ export const ImageMenuItemNewCanvasFromImageSubMenu = memo(() => {
const onClickNewCanvasWithControlLayerFromImageWithResize = useCallback(async () => {
const { dispatch, getState } = store;
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
await newCanvasFromImage({
imageDTO,
withResize: true,
withInpaintMask: true,
type: 'control_layer',
dispatch,
getState,
});
await newCanvasFromImage({ imageDTO, withResize: true, type: 'control_layer', dispatch, getState });
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),

View File

@@ -1,6 +1,5 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
@@ -15,7 +14,7 @@ export const ImageMenuItemSendToUpscale = memo(() => {
const imageDTO = useImageDTOContext();
const handleSendToCanvas = useCallback(() => {
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
dispatch(upscaleInitialImageChanged(imageDTO));
navigationApi.switchToTab('upscaling');
toast({
id: 'SENT_TO_CANVAS',

View File

@@ -1,23 +1,13 @@
import { objectEquals } from '@observ33r/object-equals';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { uniq } from 'es-toolkit/compat';
import type { BoardRecordOrderBy } from 'services/api/types';
import { assert } from 'tsafe';
import {
type BoardId,
type ComparisonMode,
type GalleryState,
type GalleryView,
type OrderDir,
zGalleryState,
} from './types';
import type { BoardId, ComparisonMode, GalleryState, GalleryView, OrderDir } from './types';
const getInitialState = (): GalleryState => ({
const initialGalleryState: GalleryState = {
selection: [],
shouldAutoSwitch: true,
autoAssignBoardOnClick: true,
@@ -36,11 +26,11 @@ const getInitialState = (): GalleryState => ({
shouldShowArchivedBoards: false,
boardsListOrderBy: 'created_at',
boardsListOrderDir: 'DESC',
});
};
const slice = createSlice({
export const gallerySlice = createSlice({
name: 'gallery',
initialState: getInitialState(),
initialState: initialGalleryState,
reducers: {
imageSelected: (state, action: PayloadAction<string | null>) => {
// Let's be efficient here and not update the selection unless it has actually changed. This helps to prevent
@@ -197,22 +187,21 @@ export const {
searchTermChanged,
boardsListOrderByChanged,
boardsListOrderDirChanged,
} = slice.actions;
} = gallerySlice.actions;
export const selectGallerySlice = (state: RootState) => state.gallery;
export const gallerySliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zGalleryState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zGalleryState.parse(state);
},
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'],
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateGalleryState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const galleryPersistConfig: PersistConfig<GalleryState> = {
name: gallerySlice.name,
initialState: initialGalleryState,
migrate: migrateGalleryState,
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'],
};

View File

@@ -1,13 +0,0 @@
import type { S } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import type { BoardRecordOrderBy } from './types';
describe('Gallery Types', () => {
// Ensure zod types match OpenAPI types
test('BoardRecordOrderBy', () => {
assert<Equals<BoardRecordOrderBy, S['BoardRecordOrderBy']>>();
});
});

View File

@@ -1,41 +1,31 @@
import type { ImageCategory } from 'services/api/types';
import z from 'zod';
const zGalleryView = z.enum(['images', 'assets']);
export type GalleryView = z.infer<typeof zGalleryView>;
const zBoardId = z.union([z.literal('none'), z.intersection(z.string(), z.record(z.never(), z.never()))]);
export type BoardId = z.infer<typeof zBoardId>;
const zComparisonMode = z.enum(['slider', 'side-by-side', 'hover']);
export type ComparisonMode = z.infer<typeof zComparisonMode>;
const zComparisonFit = z.enum(['contain', 'fill']);
export type ComparisonFit = z.infer<typeof zComparisonFit>;
const zOrderDir = z.enum(['ASC', 'DESC']);
export type OrderDir = z.infer<typeof zOrderDir>;
const zBoardRecordOrderBy = z.enum(['created_at', 'board_name']);
export type BoardRecordOrderBy = z.infer<typeof zBoardRecordOrderBy>;
import type { BoardRecordOrderBy, ImageCategory } from 'services/api/types';
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
export const zGalleryState = z.object({
selection: z.array(z.string()),
shouldAutoSwitch: z.boolean(),
autoAssignBoardOnClick: z.boolean(),
autoAddBoardId: zBoardId,
galleryImageMinimumWidth: z.number(),
selectedBoardId: zBoardId,
galleryView: zGalleryView,
boardSearchText: z.string(),
starredFirst: z.boolean(),
orderDir: zOrderDir,
searchTerm: z.string(),
alwaysShowImageSizeBadge: z.boolean(),
imageToCompare: z.string().nullable(),
comparisonMode: zComparisonMode,
comparisonFit: zComparisonFit,
shouldShowArchivedBoards: z.boolean(),
boardsListOrderBy: zBoardRecordOrderBy,
boardsListOrderDir: zOrderDir,
});
export type GalleryView = 'images' | 'assets';
export type BoardId = 'none' | (string & Record<never, never>);
export type ComparisonMode = 'slider' | 'side-by-side' | 'hover';
export type ComparisonFit = 'contain' | 'fill';
export type OrderDir = 'ASC' | 'DESC';
export type GalleryState = z.infer<typeof zGalleryState>;
export type GalleryState = {
selection: string[];
shouldAutoSwitch: boolean;
autoAssignBoardOnClick: boolean;
autoAddBoardId: BoardId;
galleryImageMinimumWidth: number;
selectedBoardId: BoardId;
galleryView: GalleryView;
boardSearchText: string;
starredFirst: boolean;
orderDir: OrderDir;
searchTerm: string;
alwaysShowImageSizeBadge: boolean;
imageToCompare: string | null;
comparisonMode: ComparisonMode;
comparisonFit: ComparisonFit;
shouldShowArchivedBoards: boolean;
boardsListOrderBy: BoardRecordOrderBy;
boardsListOrderDir: OrderDir;
};

View File

@@ -58,7 +58,7 @@ export const setRegionalGuidanceReferenceImage = (arg: {
export const setUpscaleInitialImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
const { imageDTO, dispatch } = arg;
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
dispatch(upscaleInitialImageChanged(imageDTO));
};
export const setNodeImageFieldImage = (arg: {

View File

@@ -89,7 +89,6 @@ import { t } from 'i18next';
import type { ComponentType } from 'react';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, ModelType } from 'services/api/types';
import { assert } from 'tsafe';
@@ -788,55 +787,11 @@ const LoRAs: CollectionMetadataHandler<LoRA[]> = {
const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
[SingleMetadataKey]: true,
type: 'CanvasLayers',
parse: async (metadata, store) => {
parse: async (metadata) => {
const raw = getProperty(metadata, 'canvas_v2_metadata');
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
// the zImageWithDims schema.
const parsed = await zCanvasMetadata.parseAsync(raw);
for (const entity of parsed.controlLayers) {
if (entity.controlAdapter.model) {
await throwIfModelDoesNotExist(entity.controlAdapter.model.key, store);
}
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
}
for (const entity of parsed.inpaintMasks) {
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
}
for (const entity of parsed.rasterLayers) {
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
}
for (const entity of parsed.regionalGuidance) {
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
for (const refImage of entity.referenceImages) {
if (refImage.config.image) {
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
}
if (refImage.config.model) {
await throwIfModelDoesNotExist(refImage.config.model.key, store);
}
}
}
return Promise.resolve(parsed);
},
recall: (value, store) => {
@@ -869,39 +824,27 @@ const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
const RefImages: CollectionMetadataHandler<RefImageState[]> = {
[CollectionMetadataKey]: true,
type: 'RefImages',
parse: async (metadata, store) => {
let parsed: RefImageState[] | null = null;
parse: async (metadata) => {
try {
// First attempt to parse from the v6 slot
const raw = getProperty(metadata, 'ref_images');
parsed = z.array(zRefImageState).parse(raw);
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
// the zImageWithDims schema.
const parsed = await z.array(zRefImageState).parseAsync(raw);
return Promise.resolve(parsed);
} catch {
// Fall back to extracting from canvas metadata]
const raw = getProperty(metadata, 'canvas_v2_metadata.referenceImages.entities');
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
// the zImageWithDims schema.
const oldParsed = await z.array(zCanvasReferenceImageState_OLD).parseAsync(raw);
parsed = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({
const parsed: RefImageState[] = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({
id,
config: ipAdapter,
isEnabled,
}));
return parsed;
}
if (!parsed) {
throw new Error('No valid reference images found in metadata');
}
for (const refImage of parsed) {
if (refImage.config.image) {
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
}
if (refImage.config.model) {
await throwIfModelDoesNotExist(refImage.config.model.key, store);
}
}
return parsed;
},
recall: (value, store) => {
const entities = value.map((data) => ({ ...data, id: getPrefixedId('reference_image') }));
@@ -1298,19 +1241,3 @@ const isCompatibleWithMainModel = (candidate: ModelIdentifierField, store: AppSt
}
return candidate.base === base;
};
const throwIfImageDoesNotExist = async (name: string, store: AppStore): Promise<void> => {
try {
await store.dispatch(imagesApi.endpoints.getImageDTO.initiate(name, { subscribe: false })).unwrap();
} catch {
throw new Error(`Image with name ${name} does not exist`);
}
};
const throwIfModelDoesNotExist = async (key: string, store: AppStore): Promise<void> => {
try {
await store.dispatch(modelsApi.endpoints.getModelConfig.initiate(key, { subscribe: false }));
} catch {
throw new Error(`Model with key ${key} does not exist`);
}
};

View File

@@ -1,6 +1,7 @@
import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
/**
* Raised when a model config is unable to be fetched.
@@ -46,6 +47,45 @@ const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
}
};
/**
* Fetches the model config for a given model name, base model, and model type. This provides backwards compatibility
* for MM1 model identifiers.
* @param name The model name.
* @param base The base model.
* @param type The model type.
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type: ModelType): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(
modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }, { subscribe: false })
);
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
}
};
/**
* Fetches the model config given an identifier. First attempts to fetch by key, then falls back to fetching by attrs.
* @param identifier The model identifier.
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise<AnyModelConfig> => {
try {
return await fetchModelConfig(identifier.key);
} catch {
try {
return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type);
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`);
}
}
};
/**
* Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.
* @param key The model key.

View File

@@ -1,28 +1,21 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { zModelType } from 'features/nodes/types/common';
import { assert } from 'tsafe';
import z from 'zod';
import type { PersistConfig, RootState } from 'app/store/store';
import type { ModelType } from 'services/api/types';
const zFilterableModelType = zModelType.exclude(['onnx']).or(z.literal('refiner'));
export type FilterableModelType = z.infer<typeof zFilterableModelType>;
export type FilterableModelType = Exclude<ModelType, 'onnx'> | 'refiner';
const zModelManagerState = z.object({
_version: z.literal(1),
selectedModelKey: z.string().nullable(),
selectedModelMode: z.enum(['edit', 'view']),
searchTerm: z.string(),
filteredModelType: zFilterableModelType.nullable(),
scanPath: z.string().optional(),
shouldInstallInPlace: z.boolean(),
});
type ModelManagerState = {
_version: 1;
selectedModelKey: string | null;
selectedModelMode: 'edit' | 'view';
searchTerm: string;
filteredModelType: FilterableModelType | null;
scanPath: string | undefined;
shouldInstallInPlace: boolean;
};
type ModelManagerState = z.infer<typeof zModelManagerState>;
const getInitialState = (): ModelManagerState => ({
const initialModelManagerState: ModelManagerState = {
_version: 1,
selectedModelKey: null,
selectedModelMode: 'view',
@@ -30,11 +23,11 @@ const getInitialState = (): ModelManagerState => ({
searchTerm: '',
scanPath: undefined,
shouldInstallInPlace: true,
});
};
const slice = createSlice({
export const modelManagerV2Slice = createSlice({
name: 'modelmanagerV2',
initialState: getInitialState(),
initialState: initialModelManagerState,
reducers: {
setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
state.selectedModelMode = 'view';
@@ -65,22 +58,21 @@ export const {
setSelectedModelMode,
setScanPath,
shouldInstallInPlaceChanged,
} = slice.actions;
} = modelManagerV2Slice.actions;
export const modelManagerSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zModelManagerState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zModelManagerState.parse(state);
},
persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'],
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateModelManagerState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const modelManagerV2PersistConfig: PersistConfig<ModelManagerState> = {
name: modelManagerV2Slice.name,
initialState: initialModelManagerState,
migrate: migrateModelManagerState,
persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'],
};
export const selectModelManagerV2Slice = (state: RootState) => state.modelmanagerV2;

View File

@@ -14,13 +14,7 @@ import type {
ReactFlowProps,
ReactFlowState,
} from '@xyflow/react';
import {
Background,
ReactFlow,
SelectionMode,
useStore as useReactFlowStore,
useUpdateNodeInternals,
} from '@xyflow/react';
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from '@xyflow/react';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { $isSelectingOutputNode, $outputNodeId } from 'features/nodes/components/sidePanel/workflow/publish';
@@ -262,7 +256,7 @@ export const Flow = memo(() => {
style={flowStyles}
onPaneClick={handlePaneClick}
deleteKeyCode={null}
selectionMode={selectionMode === 'full' ? SelectionMode.Full : SelectionMode.Partial}
selectionMode={selectionMode}
elevateEdgesOnSelect
nodeDragThreshold={1}
noDragClassName={NO_DRAG_CLASS}

View File

@@ -11,15 +11,14 @@ import type {
XYPosition,
} from '@xyflow/react';
import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from '@xyflow/react';
import type { SliceConfig } from 'app/store/types';
import type { PersistConfig } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { isPlainObject } from 'es-toolkit';
import {
addElement,
removeElement,
reparentElement,
} from 'features/nodes/components/sidePanel/builder/form-manipulation';
import { type NodesState, zNodesState } from 'features/nodes/store/types';
import type { NodesState } from 'features/nodes/store/types';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
BoardFieldValue,
@@ -128,7 +127,6 @@ import {
import { atom, computed } from 'nanostores';
import type { MouseEvent } from 'react';
import type { UndoableOptions } from 'redux-undo';
import { assert } from 'tsafe';
import type { z } from 'zod';
import type { PendingConnection, Templates } from './types';
@@ -153,11 +151,11 @@ export const getInitialWorkflow = (): Omit<NodesState, 'mode' | 'formFieldInitia
};
};
const getInitialState = (): NodesState => ({
const initialState: NodesState = {
_version: 1,
formFieldInitialValues: {},
...getInitialWorkflow(),
});
};
type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string;
@@ -210,9 +208,9 @@ const fieldValueReducer = <T extends FieldValue>(
field.value = result.data;
};
const slice = createSlice({
export const nodesSlice = createSlice({
name: 'nodes',
initialState: getInitialState(),
initialState: initialState,
reducers: {
nodesChanged: (state, action: PayloadAction<NodeChange<AnyNode>[]>) => {
// In v12.7.0, @xyflow/react added a `domAttributes` property to the node data. One DOM attribute is
@@ -590,7 +588,7 @@ const slice = createSlice({
}
node.data.notes = value;
},
nodeEditorReset: () => getInitialState(),
nodeEditorReset: () => deepClone(initialState),
workflowNameChanged: (state, action: PayloadAction<string>) => {
state.name = action.payload;
},
@@ -675,7 +673,7 @@ const slice = createSlice({
const formFieldInitialValues = getFormFieldInitialValues(workflowExtra.form, nodes);
return {
...getInitialState(),
...deepClone(initialState),
...deepClone(workflowExtra),
formFieldInitialValues,
nodes: nodes.map((node) => ({ ...SHARED_NODE_PROPERTIES, ...node })),
@@ -760,7 +758,7 @@ export const {
workflowLoaded,
undo,
redo,
} = slice.actions;
} = nodesSlice.actions;
export const $cursorPos = atom<XYPosition | null>(null);
export const $templates = atom<Templates>({});
@@ -777,6 +775,21 @@ export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
export const $addNodeCmdk = atom(false);
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateNodesState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const nodesPersistConfig: PersistConfig<NodesState> = {
name: nodesSlice.name,
initialState: initialState,
migrate: migrateNodesState,
persistDenylist: [],
};
type NodeSelectionAction = {
type: ReturnType<typeof nodesChanged>['type'];
payload: NodeSelectionChange[];
@@ -880,10 +893,10 @@ const isHighFrequencyWorkflowDetailsAction = isAnyOf(
// a note in a notes node, we don't want to create a new undo group for every keystroke.
const isHighFrequencyNodeScopedAction = isAnyOf(nodeLabelChanged, nodeNotesChanged, notesNodeValueChanged);
const reduxUndoOptions: UndoableOptions<NodesState, UnknownAction> = {
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
limit: 64,
undoType: slice.actions.undo.type,
redoType: slice.actions.redo.type,
undoType: nodesSlice.actions.undo.type,
redoType: nodesSlice.actions.redo.type,
groupBy: (action, _state, _history) => {
if (isHighFrequencyFieldChangeAction(action)) {
// Group by type, node id and field name
@@ -915,7 +928,7 @@ const reduxUndoOptions: UndoableOptions<NodesState, UnknownAction> = {
},
filter: (action, _state, _history) => {
// Ignore all actions from other slices
if (!action.type.startsWith(slice.name)) {
if (!action.type.startsWith(nodesSlice.name)) {
return false;
}
// Ignore actions that only select or deselect nodes and edges
@@ -930,24 +943,6 @@ const reduxUndoOptions: UndoableOptions<NodesState, UnknownAction> = {
},
};
export const nodesSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zNodesState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zNodesState.parse(state);
},
},
undoableConfig: {
reduxUndoOptions,
},
};
// The form builder's initial values are based on the current values of the node fields in the workflow.
export const getFormFieldInitialValues = (form: BuilderForm, nodes: NodesState['nodes']) => {
const formFieldInitialValues: Record<string, StatefulFieldValue> = {};

View File

@@ -1,8 +1,7 @@
import type { HandleType } from '@xyflow/react';
import { type FieldInputTemplate, type FieldOutputTemplate, zStatefulFieldValue } from 'features/nodes/types/field';
import { type InvocationTemplate, type NodeExecutionState, zAnyEdge, zAnyNode } from 'features/nodes/types/invocation';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import z from 'zod';
import type { FieldInputTemplate, FieldOutputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
import type { AnyEdge, AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
@@ -14,13 +13,11 @@ export type PendingConnection = {
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
};
export const zWorkflowMode = z.enum(['edit', 'view']);
export type WorkflowMode = z.infer<typeof zWorkflowMode>;
export const zNodesState = z.object({
_version: z.literal(1),
nodes: z.array(zAnyNode),
edges: z.array(zAnyEdge),
formFieldInitialValues: z.record(z.string(), zStatefulFieldValue),
...zWorkflowV3.omit({ nodes: true, edges: true, is_published: true }).shape,
});
export type NodesState = z.infer<typeof zNodesState>;
export type WorkflowMode = 'edit' | 'view';
export type NodesState = {
_version: 1;
nodes: AnyNode[];
edges: AnyEdge[];
formFieldInitialValues: Record<string, StatefulFieldValue>;
} & Omit<WorkflowV3, 'nodes' | 'edges' | 'is_published'>;

View File

@@ -1,43 +1,34 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { type WorkflowMode, zWorkflowMode } from 'features/nodes/store/types';
import type { PersistConfig, RootState } from 'app/store/store';
import type { WorkflowMode } from 'features/nodes/store/types';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import { atom, computed } from 'nanostores';
import {
type SQLiteDirection,
type WorkflowRecordOrderBy,
zSQLiteDirection,
zWorkflowRecordOrderBy,
} from 'services/api/types';
import z from 'zod';
import type { SQLiteDirection, WorkflowRecordOrderBy } from 'services/api/types';
const zWorkflowLibraryView = z.enum(['recent', 'yours', 'private', 'shared', 'defaults', 'published']);
export type WorkflowLibraryView = z.infer<typeof zWorkflowLibraryView>;
export type WorkflowLibraryView = 'recent' | 'yours' | 'private' | 'shared' | 'defaults' | 'published';
const zWorkflowLibraryState = z.object({
mode: zWorkflowMode,
view: zWorkflowLibraryView,
orderBy: zWorkflowRecordOrderBy,
direction: zSQLiteDirection,
searchTerm: z.string(),
selectedTags: z.array(z.string()),
});
type WorkflowLibraryState = z.infer<typeof zWorkflowLibraryState>;
type WorkflowLibraryState = {
mode: WorkflowMode;
view: WorkflowLibraryView;
orderBy: WorkflowRecordOrderBy;
direction: SQLiteDirection;
searchTerm: string;
selectedTags: string[];
};
const getInitialState = (): WorkflowLibraryState => ({
const initialWorkflowLibraryState: WorkflowLibraryState = {
mode: 'view',
searchTerm: '',
orderBy: 'opened_at',
direction: 'DESC',
selectedTags: [],
view: 'defaults',
});
};
const slice = createSlice({
export const workflowLibrarySlice = createSlice({
name: 'workflowLibrary',
initialState: getInitialState(),
initialState: initialWorkflowLibraryState,
reducers: {
workflowModeChanged: (state, action: PayloadAction<WorkflowMode>) => {
state.mode = action.payload;
@@ -82,15 +73,16 @@ export const {
workflowLibraryTagToggled,
workflowLibraryTagsReset,
workflowLibraryViewChanged,
} = slice.actions;
} = workflowLibrarySlice.actions;
export const workflowLibrarySliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zWorkflowLibraryState,
getInitialState,
persistConfig: {
migrate: (state) => zWorkflowLibraryState.parse(state),
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateWorkflowLibraryState = (state: any): any => state;
export const workflowLibraryPersistConfig: PersistConfig<WorkflowLibraryState> = {
name: workflowLibrarySlice.name,
initialState: initialWorkflowLibraryState,
migrate: migrateWorkflowLibraryState,
persistDenylist: [],
};
const selectWorkflowLibrarySlice = (state: RootState) => state.workflowLibrary;

View File

@@ -1,10 +1,8 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { SelectionMode } from '@xyflow/react';
import type { PersistConfig, RootState } from 'app/store/store';
import type { Selector } from 'react-redux';
import { assert } from 'tsafe';
import z from 'zod';
export const zLayeringStrategy = z.enum(['network-simplex', 'longest-path']);
@@ -13,28 +11,25 @@ export const zLayoutDirection = z.enum(['TB', 'LR']);
type LayoutDirection = z.infer<typeof zLayoutDirection>;
export const zNodeAlignment = z.enum(['UL', 'UR', 'DL', 'DR']);
type NodeAlignment = z.infer<typeof zNodeAlignment>;
const zSelectionMode = z.enum(['partial', 'full']);
const zWorkflowSettingsState = z.object({
_version: z.literal(1),
shouldShowMinimapPanel: z.boolean(),
layeringStrategy: zLayeringStrategy,
nodeSpacing: z.number(),
layerSpacing: z.number(),
layoutDirection: zLayoutDirection,
shouldValidateGraph: z.boolean(),
shouldAnimateEdges: z.boolean(),
nodeAlignment: zNodeAlignment,
nodeOpacity: z.number(),
shouldSnapToGrid: z.boolean(),
shouldColorEdges: z.boolean(),
shouldShowEdgeLabels: z.boolean(),
selectionMode: zSelectionMode,
});
export type WorkflowSettingsState = {
_version: 1;
shouldShowMinimapPanel: boolean;
layeringStrategy: LayeringStrategy;
nodeSpacing: number;
layerSpacing: number;
layoutDirection: LayoutDirection;
shouldValidateGraph: boolean;
shouldAnimateEdges: boolean;
nodeAlignment: NodeAlignment;
nodeOpacity: number;
shouldSnapToGrid: boolean;
shouldColorEdges: boolean;
shouldShowEdgeLabels: boolean;
selectionMode: SelectionMode;
};
export type WorkflowSettingsState = z.infer<typeof zWorkflowSettingsState>;
const getInitialState = (): WorkflowSettingsState => ({
const initialState: WorkflowSettingsState = {
_version: 1,
shouldShowMinimapPanel: true,
layeringStrategy: 'network-simplex',
@@ -48,12 +43,12 @@ const getInitialState = (): WorkflowSettingsState => ({
shouldColorEdges: true,
shouldShowEdgeLabels: false,
nodeOpacity: 1,
selectionMode: 'partial',
});
selectionMode: SelectionMode.Partial,
};
const slice = createSlice({
export const workflowSettingsSlice = createSlice({
name: 'workflowSettings',
initialState: getInitialState(),
initialState,
reducers: {
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowMinimapPanel = action.payload;
@@ -92,7 +87,7 @@ const slice = createSlice({
state.nodeAlignment = action.payload;
},
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
state.selectionMode = action.payload ? 'full' : 'partial';
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
},
},
});
@@ -111,21 +106,21 @@ export const {
shouldValidateGraphChanged,
nodeOpacityChanged,
selectionModeChanged,
} = slice.actions;
} = workflowSettingsSlice.actions;
export const workflowSettingsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zWorkflowSettingsState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zWorkflowSettingsState.parse(state);
},
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateWorkflowSettingsState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const workflowSettingsPersistConfig: PersistConfig<WorkflowSettingsState> = {
name: workflowSettingsSlice.name,
initialState,
migrate: migrateWorkflowSettingsState,
persistDenylist: [],
};
export const selectWorkflowSettingsSlice = (state: RootState) => state.workflowSettings;

View File

@@ -92,7 +92,7 @@ export const zMainModelBase = z.enum([
]);
type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
export const zModelType = z.enum([
const zModelType = z.enum([
'main',
'vae',
'lora',

View File

@@ -43,7 +43,7 @@ export const zNotesNodeData = z.object({
isOpen: z.boolean(),
notes: z.string(),
});
const zCurrentImageNodeData = z.object({
const _zCurrentImageNodeData = z.object({
id: z.string().trim().min(1),
type: z.literal('current_image'),
label: z.string(),
@@ -52,35 +52,12 @@ const zCurrentImageNodeData = z.object({
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
type CurrentImageNodeData = z.infer<typeof _zCurrentImageNodeData>;
const zInvocationNodeValidationSchema = z.looseObject({
type: z.literal('invocation'),
data: zInvocationNodeData,
});
const zInvocationNode = z.custom<Node<InvocationNodeData, 'invocation'>>(
(val) => zInvocationNodeValidationSchema.safeParse(val).success
);
export type InvocationNode = z.infer<typeof zInvocationNode>;
const zNotesNodeValidationSchema = z.looseObject({
type: z.literal('notes'),
data: zNotesNodeData,
});
const zNotesNode = z.custom<Node<NotesNodeData, 'notes'>>((val) => zNotesNodeValidationSchema.safeParse(val).success);
export type NotesNode = z.infer<typeof zNotesNode>;
const zCurrentImageNodeValidationSchema = z.looseObject({
type: z.literal('current_image'),
data: zCurrentImageNodeData,
});
const zCurrentImageNode = z.custom<Node<CurrentImageNodeData, 'current_image'>>(
(val) => zCurrentImageNodeValidationSchema.safeParse(val).success
);
export type CurrentImageNode = z.infer<typeof zCurrentImageNode>;
export const zAnyNode = z.union([zInvocationNode, zNotesNode, zCurrentImageNode]);
export type AnyNode = z.infer<typeof zAnyNode>;
export type InvocationNode = Node<InvocationNodeData, 'invocation'>;
export type NotesNode = Node<NotesNodeData, 'notes'>;
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
export type AnyNode = InvocationNode | NotesNode | CurrentImageNode;
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
Boolean(node && node.type === 'invocation');
@@ -106,29 +83,13 @@ export type NodeExecutionState = z.infer<typeof _zNodeExecutionState>;
// #endregion
// #region Edges
const zDefaultInvocationNodeEdgeValidationSchema = z.looseObject({
type: z.literal('default'),
});
const zDefaultInvocationNodeEdge = z.custom<Edge<Record<string, never>, 'default'>>(
(val) => zDefaultInvocationNodeEdgeValidationSchema.safeParse(val).success
);
export type DefaultInvocationNodeEdge = z.infer<typeof zDefaultInvocationNodeEdge>;
const zInvocationNodeEdgeCollapsedData = z.object({
const _zInvocationNodeEdgeCollapsedData = z.object({
count: z.number().int().min(1),
});
const zInvocationNodeEdgeCollapsedValidationSchema = z.looseObject({
type: z.literal('default'),
data: zInvocationNodeEdgeCollapsedData,
});
type InvocationNodeEdgeCollapsedData = z.infer<typeof zInvocationNodeEdgeCollapsedData>;
const zCollapsedInvocationNodeEdge = z.custom<Edge<InvocationNodeEdgeCollapsedData, 'collapsed'>>(
(val) => zInvocationNodeEdgeCollapsedValidationSchema.safeParse(val).success
);
export type CollapsedInvocationNodeEdge = z.infer<typeof zCollapsedInvocationNodeEdge>;
export const zAnyEdge = z.union([zDefaultInvocationNodeEdge, zCollapsedInvocationNodeEdge]);
export type AnyEdge = z.infer<typeof zAnyEdge>;
type InvocationNodeEdgeCollapsedData = z.infer<typeof _zInvocationNodeEdgeCollapsedData>;
export type DefaultInvocationNodeEdge = Edge<Record<string, never>, 'default'>;
export type CollapsedInvocationNodeEdge = Edge<InvocationNodeEdgeCollapsedData, 'collapsed'>;
export type AnyEdge = DefaultInvocationNodeEdge | CollapsedInvocationNodeEdge;
// #endregion
export const isBatchNodeType = (type: string) =>

View File

@@ -4,7 +4,6 @@ import { range } from 'es-toolkit/compat';
import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import type { components } from 'services/api/schema';
import type { Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
@@ -19,7 +18,7 @@ const getExtendedPrompts = (arg: {
// Normally, the seed behaviour implicity determines the batch size. But when we use models without seeds (like
// ChatGPT 4o) in conjunction with the per-prompt seed behaviour, we lose out on that implicit batch size. To rectify
// this, we need to create a batch of the right size by repeating the prompts.
if (seedBehaviour === 'PER_PROMPT' || API_BASE_MODELS.includes(model.base)) {
if (seedBehaviour === 'PER_PROMPT' || model.base === 'chatgpt-4o' || model.base === 'flux-kontext') {
return range(iterations).flatMap(() => prompts);
}
return prompts;

View File

@@ -1,5 +1,4 @@
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
@@ -7,35 +6,13 @@ import { ModelPicker } from 'features/parameters/components/ModelPicker';
import { selectTileControlNetModel, tileControlnetModelChanged } from 'features/parameters/store/upscaleSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
import { type ControlNetModelConfig, isControlNetModelConfig } from 'services/api/types';
const selectTileControlNetModelConfig = createSelector(
selectModelConfigsQuery,
selectTileControlNetModel,
(modelConfigs, modelIdentifierField) => {
if (!modelConfigs.data) {
return null;
}
if (!modelIdentifierField) {
return null;
}
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, modelIdentifierField.key);
if (!modelConfig) {
return null;
}
if (!isControlNetModelConfig(modelConfig)) {
return null;
}
return modelConfig;
}
);
import type { ControlNetModelConfig } from 'services/api/types';
const ParamTileControlNetModel = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const tileControlNetModel = useAppSelector(selectTileControlNetModelConfig);
const tileControlNetModel = useAppSelector(selectTileControlNetModel);
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useControlNetModels();

View File

@@ -1,21 +1,21 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { useMemo } from 'react';
import type { ImageDTO } from 'services/api/types';
const createIsTooLargeToUpscaleSelector = (imageWithDims?: ImageWithDims | null) =>
const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) =>
createSelector(selectUpscaleSlice, selectConfigSlice, (upscale, config) => {
const { upscaleModel, scale } = upscale;
const { maxUpscaleDimension } = config;
if (!maxUpscaleDimension || !upscaleModel || !imageWithDims) {
if (!maxUpscaleDimension || !upscaleModel || !imageDTO) {
// When these are missing, another warning will be shown
return false;
}
const { width, height } = imageWithDims;
const { width, height } = imageDTO;
const maxPixels = maxUpscaleDimension ** 2;
const upscaledPixels = width * scale * height * scale;
@@ -23,7 +23,7 @@ const createIsTooLargeToUpscaleSelector = (imageWithDims?: ImageWithDims | null)
return upscaledPixels > maxPixels;
});
export const useIsTooLargeToUpscale = (imageWithDims?: ImageWithDims | null) => {
const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageWithDims), [imageWithDims]);
export const useIsTooLargeToUpscale = (imageDTO?: ImageDTO | null) => {
const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageDTO), [imageDTO]);
return useAppSelector(selectIsTooLargeToUpscale);
};

View File

@@ -1,33 +1,24 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { PersistConfig, RootState } from 'app/store/store';
import type { ParameterSpandrelImageToImageModel } from 'features/parameters/types/parameterSchemas';
import type { ControlNetModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import z from 'zod';
import type { ControlNetModelConfig, ImageDTO } from 'services/api/types';
const zUpscaleState = z.object({
_version: z.literal(2),
upscaleModel: zModelIdentifierField.nullable(),
upscaleInitialImage: zImageWithDims.nullable(),
structure: z.number(),
creativity: z.number(),
tileControlnetModel: zModelIdentifierField.nullable(),
scale: z.number(),
postProcessingModel: zModelIdentifierField.nullable(),
tileSize: z.number(),
tileOverlap: z.number(),
});
export interface UpscaleState {
_version: 1;
upscaleModel: ParameterSpandrelImageToImageModel | null;
upscaleInitialImage: ImageDTO | null;
structure: number;
creativity: number;
tileControlnetModel: ControlNetModelConfig | null;
scale: number;
postProcessingModel: ParameterSpandrelImageToImageModel | null;
tileSize: number;
tileOverlap: number;
}
export type UpscaleState = z.infer<typeof zUpscaleState>;
const getInitialState = (): UpscaleState => ({
_version: 2,
const initialUpscaleState: UpscaleState = {
_version: 1,
upscaleModel: null,
upscaleInitialImage: null,
structure: 0,
@@ -37,19 +28,16 @@ const getInitialState = (): UpscaleState => ({
postProcessingModel: null,
tileSize: 1024,
tileOverlap: 128,
});
};
const slice = createSlice({
export const upscaleSlice = createSlice({
name: 'upscale',
initialState: getInitialState(),
initialState: initialUpscaleState,
reducers: {
upscaleModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
const result = zUpscaleState.shape.upscaleModel.safeParse(action.payload);
if (result.success) {
state.upscaleModel = result.data;
}
state.upscaleModel = action.payload;
},
upscaleInitialImageChanged: (state, action: PayloadAction<ImageWithDims | null>) => {
upscaleInitialImageChanged: (state, action: PayloadAction<ImageDTO | null>) => {
state.upscaleInitialImage = action.payload;
},
structureChanged: (state, action: PayloadAction<number>) => {
@@ -59,19 +47,13 @@ const slice = createSlice({
state.creativity = action.payload;
},
tileControlnetModelChanged: (state, action: PayloadAction<ControlNetModelConfig | null>) => {
const result = zUpscaleState.shape.tileControlnetModel.safeParse(action.payload);
if (result.success) {
state.tileControlnetModel = result.data;
}
state.tileControlnetModel = action.payload;
},
scaleChanged: (state, action: PayloadAction<number>) => {
state.scale = action.payload;
},
postProcessingModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
const result = zUpscaleState.shape.postProcessingModel.safeParse(action.payload);
if (result.success) {
state.postProcessingModel = result.data;
}
state.postProcessingModel = action.payload;
},
tileSizeChanged: (state, action: PayloadAction<number>) => {
state.tileSize = action.payload;
@@ -92,33 +74,21 @@ export const {
postProcessingModelChanged,
tileSizeChanged,
tileOverlapChanged,
} = slice.actions;
} = upscaleSlice.actions;
export const upscaleSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zUpscaleState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state._version = 2;
// Migrate from v1 to v2: upscaleInitialImage was an ImageDTO, now it's an ImageWithDims
if (state.upscaleInitialImage) {
const { image_name, width, height } = state.upscaleInitialImage;
state.upscaleInitialImage = {
image_name,
width,
height,
};
}
}
return zUpscaleState.parse(state);
},
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateUpscaleState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const upscalePersistConfig: PersistConfig<UpscaleState> = {
name: upscaleSlice.name,
initialState: initialUpscaleState,
migrate: migrateUpscaleState,
persistDenylist: [],
};
export const selectUpscaleSlice = (state: RootState) => state.upscale;

View File

@@ -13,13 +13,14 @@ export const CancelAllExceptCurrentButton = memo((props: ButtonProps) => {
<Button
isDisabled={api.isDisabled}
isLoading={api.isLoading}
aria-label={t('queue.clear')}
tooltip={t('queue.cancelAllExceptCurrentTooltip')}
leftIcon={<PiXCircle />}
colorScheme="error"
onClick={api.openDialog}
{...props}
>
{t('queue.cancelAllExceptCurrentTooltip')}
{t('queue.clear')}
</Button>
);
});

View File

@@ -1,29 +0,0 @@
import type { ButtonProps } from '@invoke-ai/ui-library';
import { Button } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashBold } from 'react-icons/pi';
import { useClearQueueDialog } from './ClearQueueConfirmationAlertDialog';
export const ClearQueueButton = memo((props: ButtonProps) => {
const { t } = useTranslation();
const api = useClearQueueDialog();
return (
<Button
isDisabled={api.isDisabled}
isLoading={api.isLoading}
aria-label={t('queue.clear')}
tooltip={t('queue.clearTooltip')}
leftIcon={<PiTrashBold />}
colorScheme="error"
onClick={api.openDialog}
{...props}
>
{t('queue.clear')}
</Button>
);
});
ClearQueueButton.displayName = 'ClearQueueButton';

View File

@@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next';
const [useClearQueueConfirmationAlertDialog] = buildUseBoolean(false);
export const useClearQueueDialog = () => {
const useClearQueueDialog = () => {
const dialog = useClearQueueConfirmationAlertDialog();
const clearQueue = useClearQueue();

View File

@@ -9,19 +9,15 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiListBold, PiPauseFill, PiPlayFill, PiQueueBold, PiTrashBold, PiXBold, PiXCircle } from 'react-icons/pi';
import { useClearQueueDialog } from './ClearQueueConfirmationAlertDialog';
import { PiListBold, PiPauseFill, PiPlayFill, PiQueueBold, PiXBold, PiXCircle } from 'react-icons/pi';
export const QueueActionsMenuButton = memo(() => {
const ref = useRef<HTMLDivElement>(null);
const { t } = useTranslation();
const isPauseEnabled = useFeatureStatus('pauseQueue');
const isResumeEnabled = useFeatureStatus('resumeQueue');
const isClearAllEnabled = useFeatureStatus('cancelAndClearAll');
const cancelAllExceptCurrent = useCancelAllExceptCurrentQueueItemDialog();
const cancelCurrentQueueItem = useCancelCurrentQueueItem();
const clearQueue = useClearQueueDialog();
const resumeProcessor = useResumeProcessor();
const pauseProcessor = usePauseProcessor();
const openQueue = useCallback(() => {
@@ -59,17 +55,6 @@ export const QueueActionsMenuButton = memo(() => {
>
{t('queue.cancelAllExceptCurrentTooltip')}
</MenuItem>
{isClearAllEnabled && (
<MenuItem
isDestructive
icon={<PiTrashBold />}
onClick={clearQueue.openDialog}
isLoading={clearQueue.isLoading}
isDisabled={clearQueue.isDisabled}
>
{t('queue.clearTooltip')}
</MenuItem>
)}
{isResumeEnabled && (
<MenuItem
icon={<PiPlayFill />}

View File

@@ -4,7 +4,6 @@ import { memo } from 'react';
import { CancelAllExceptCurrentButton } from './CancelAllExceptCurrentButton';
import ClearModelCacheButton from './ClearModelCacheButton';
import { ClearQueueButton } from './ClearQueueButton';
import PauseProcessorButton from './PauseProcessorButton';
import PruneQueueButton from './PruneQueueButton';
import ResumeProcessorButton from './ResumeProcessorButton';
@@ -12,20 +11,19 @@ import ResumeProcessorButton from './ResumeProcessorButton';
const QueueTabQueueControls = () => {
const isPauseEnabled = useFeatureStatus('pauseQueue');
const isResumeEnabled = useFeatureStatus('resumeQueue');
const isClearQueueEnabled = useFeatureStatus('cancelAndClearAll');
return (
<Flex flexDir="column" layerStyle="first" borderRadius="base" p={2} gap={2}>
<Flex gap={2}>
{(isPauseEnabled || isResumeEnabled) && (
<ButtonGroup orientation="vertical" size="sm">
<ButtonGroup w={28} orientation="vertical" size="sm">
{isResumeEnabled && <ResumeProcessorButton />}
{isPauseEnabled && <PauseProcessorButton />}
</ButtonGroup>
)}
<ButtonGroup orientation="vertical" size="sm">
<ButtonGroup w={28} orientation="vertical" size="sm">
<PruneQueueButton />
{isClearQueueEnabled ? <ClearQueueButton /> : <CancelAllExceptCurrentButton />}
<CancelAllExceptCurrentButton />
</ButtonGroup>
</Flex>
<ClearModelCacheButton />

View File

@@ -1,27 +1,24 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import z from 'zod';
const zQueueState = z.object({
listCursor: z.number().optional(),
listPriority: z.number().optional(),
selectedQueueItem: z.string().optional(),
resumeProcessorOnEnqueue: z.boolean(),
});
type QueueState = z.infer<typeof zQueueState>;
interface QueueState {
listCursor: number | undefined;
listPriority: number | undefined;
selectedQueueItem: string | undefined;
resumeProcessorOnEnqueue: boolean;
}
const getInitialState = (): QueueState => ({
const initialQueueState: QueueState = {
listCursor: undefined,
listPriority: undefined,
selectedQueueItem: undefined,
resumeProcessorOnEnqueue: true,
});
};
const slice = createSlice({
export const queueSlice = createSlice({
name: 'queue',
initialState: getInitialState(),
initialState: initialQueueState,
reducers: {
listCursorChanged: (state, action: PayloadAction<number | undefined>) => {
state.listCursor = action.payload;
@@ -36,13 +33,7 @@ const slice = createSlice({
},
});
export const { listCursorChanged, listPriorityChanged, listParamsReset } = slice.actions;
export const queueSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zQueueState,
getInitialState,
};
export const { listCursorChanged, listPriorityChanged, listParamsReset } = queueSlice.actions;
const selectQueueSlice = (state: RootState) => state.queue;
const createQueueSelector = <T>(selector: Selector<QueueState, T>) => createSelector(selectQueueSlice, selector);

View File

@@ -1,7 +1,6 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import type { SetUpscaleInitialImageDndTargetData } from 'features/dnd/dnd';
import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
@@ -11,13 +10,11 @@ import { selectUpscaleInitialImage, upscaleInitialImageChanged } from 'features/
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { useImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const UpscaleInitialImage = () => {
const dispatch = useAppDispatch();
const upscaleInitialImage = useAppSelector(selectUpscaleInitialImage);
const imageDTO = useImageDTO(upscaleInitialImage?.image_name);
const imageDTO = useAppSelector(selectUpscaleInitialImage);
const dndTargetData = useMemo<SetUpscaleInitialImageDndTargetData>(
() => setUpscaleInitialImageDndTarget.getData(),
[]
@@ -29,7 +26,7 @@ export const UpscaleInitialImage = () => {
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
dispatch(upscaleInitialImageChanged(imageDTO));
},
[dispatch]
);

View File

@@ -31,10 +31,8 @@ export const UpscaleWarning = () => {
const validModel = modelConfigs.find((cnetModel) => {
return cnetModel.base === model?.base && cnetModel.name.toLowerCase().includes('tile');
});
if (tileControlnetModel?.key !== validModel?.key) {
dispatch(tileControlnetModelChanged(validModel || null));
}
}, [dispatch, model?.base, modelConfigs, tileControlnetModel?.key]);
dispatch(tileControlnetModelChanged(validModel || null));
}, [model?.base, modelConfigs, dispatch]);
const isBaseModelCompatible = useMemo(() => {
return model && ['sd-1', 'sdxl'].includes(model.base);

View File

@@ -1,33 +1,23 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { atom } from 'nanostores';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { assert } from 'tsafe';
import z from 'zod';
const zStylePresetState = z.object({
activeStylePresetId: z.string().nullable(),
searchTerm: z.string(),
viewMode: z.boolean(),
showPromptPreviews: z.boolean(),
});
import type { StylePresetState } from './types';
type StylePresetState = z.infer<typeof zStylePresetState>;
const getInitialState = (): StylePresetState => ({
const initialState: StylePresetState = {
activeStylePresetId: null,
searchTerm: '',
viewMode: false,
showPromptPreviews: false,
});
};
const slice = createSlice({
export const stylePresetSlice = createSlice({
name: 'stylePreset',
initialState: getInitialState(),
initialState: initialState,
reducers: {
activeStylePresetIdChanged: (state, action: PayloadAction<string | null>) => {
state.activeStylePresetId = action.payload;
@@ -44,7 +34,7 @@ const slice = createSlice({
},
extraReducers(builder) {
builder.addCase(paramsReset, () => {
return getInitialState();
return deepClone(initialState);
});
builder.addMatcher(stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled, (state, action) => {
if (state.activeStylePresetId === null) {
@@ -68,21 +58,21 @@ const slice = createSlice({
});
export const { activeStylePresetIdChanged, searchTermChanged, viewModeChanged, showPromptPreviewsChanged } =
slice.actions;
stylePresetSlice.actions;
export const stylePresetSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zStylePresetState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zStylePresetState.parse(state);
},
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateStylePresetState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const stylePresetPersistConfig: PersistConfig<StylePresetState> = {
name: stylePresetSlice.name,
initialState,
migrate: migrateStylePresetState,
persistDenylist: [],
};
export const selectStylePresetSlice = (state: RootState) => state.stylePreset;

View File

@@ -0,0 +1,6 @@
export type StylePresetState = {
activeStylePresetId: string | null;
searchTerm: string;
viewMode: boolean;
showPromptPreviews: boolean;
};

View File

@@ -14,11 +14,11 @@ import {
Switch,
Text,
} from '@invoke-ai/ui-library';
import { useClearStorage } from 'app/contexts/clear-storage-context';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { selectShouldUseCPUNoise, shouldUseCpuNoiseChanged } from 'features/controlLayers/store/paramsSlice';
import { useRefreshAfterResetModal } from 'features/system/components/SettingsModal/RefreshAfterResetModal';
import { SettingsDeveloperLogIsEnabled } from 'features/system/components/SettingsModal/SettingsDeveloperLogIsEnabled';

View File

@@ -1,25 +1,193 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { getDefaultAppConfig, type PartialAppConfig, zAppConfig } from 'app/types/invokeai';
import type { AppConfig, NumericalParameterConfig, PartialAppConfig } from 'app/types/invokeai';
import { merge } from 'es-toolkit/compat';
import z from 'zod';
const zConfigState = z.object({
...zAppConfig.shape,
didLoad: z.boolean(),
});
type ConfigState = z.infer<typeof zConfigState>;
const baseDimensionConfig: NumericalParameterConfig = {
initial: 512, // determined by model selection, unused in practice
sliderMin: 64,
sliderMax: 1536,
numberInputMin: 64,
numberInputMax: 4096,
fineStep: 8,
coarseStep: 64,
};
const getInitialState = (): ConfigState => ({
...getDefaultAppConfig(),
const initialConfigState: AppConfig & { didLoad: boolean } = {
didLoad: false,
});
isLocal: true,
shouldUpdateImagesOnConnect: false,
shouldFetchMetadataFromApi: false,
allowPrivateBoards: false,
allowPrivateStylePresets: false,
allowClientSideUpload: false,
allowPublishWorkflows: false,
allowPromptExpansion: false,
shouldShowCredits: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'],
nodesAllowlist: undefined,
nodesDenylist: undefined,
sd: {
disabledControlNetModels: [],
disabledControlNetProcessors: [],
iterations: {
initial: 1,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 1,
},
width: { ...baseDimensionConfig },
height: { ...baseDimensionConfig },
boundingBoxWidth: { ...baseDimensionConfig },
boundingBoxHeight: { ...baseDimensionConfig },
scaledBoundingBoxWidth: { ...baseDimensionConfig },
scaledBoundingBoxHeight: { ...baseDimensionConfig },
scheduler: 'dpmpp_3m_k',
vaePrecision: 'fp32',
steps: {
initial: 30,
sliderMin: 1,
sliderMax: 100,
numberInputMin: 1,
numberInputMax: 500,
fineStep: 1,
coarseStep: 1,
},
guidance: {
initial: 7,
sliderMin: 1,
sliderMax: 20,
numberInputMin: 1,
numberInputMax: 200,
fineStep: 0.1,
coarseStep: 0.5,
},
img2imgStrength: {
initial: 0.7,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceStrength: {
initial: 0.3,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
hrfStrength: {
initial: 0.45,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceEdgeSize: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 1024,
fineStep: 8,
coarseStep: 16,
},
cfgRescaleMultiplier: {
initial: 0,
sliderMin: 0,
sliderMax: 0.99,
numberInputMin: 0,
numberInputMax: 0.99,
fineStep: 0.05,
coarseStep: 0.1,
},
clipSkip: {
initial: 0,
sliderMin: 0,
sliderMax: 12, // determined by model selection, unused in practice
numberInputMin: 0,
numberInputMax: 12, // determined by model selection, unused in practice
fineStep: 1,
coarseStep: 1,
},
infillPatchmatchDownscaleSize: {
initial: 1,
sliderMin: 1,
sliderMax: 10,
numberInputMin: 1,
numberInputMax: 10,
fineStep: 1,
coarseStep: 1,
},
infillTileSize: {
initial: 32,
sliderMin: 16,
sliderMax: 64,
numberInputMin: 16,
numberInputMax: 256,
fineStep: 1,
coarseStep: 1,
},
maskBlur: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 512,
fineStep: 1,
coarseStep: 1,
},
ca: {
weight: {
initial: 1,
sliderMin: 0,
sliderMax: 2,
numberInputMin: -1,
numberInputMax: 2,
fineStep: 0.01,
coarseStep: 0.05,
},
},
dynamicPrompts: {
maxPrompts: {
initial: 100,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 10,
},
},
},
flux: {
guidance: {
initial: 4,
sliderMin: 2,
sliderMax: 6,
numberInputMin: 1,
numberInputMax: 20,
fineStep: 0.1,
coarseStep: 0.5,
},
},
};
const slice = createSlice({
export const configSlice = createSlice({
name: 'config',
initialState: getInitialState(),
initialState: initialConfigState,
reducers: {
configChanged: (state, action: PayloadAction<PartialAppConfig>) => {
merge(state, action.payload);
@@ -28,16 +196,11 @@ const slice = createSlice({
},
});
export const { configChanged } = slice.actions;
export const configSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zConfigState,
getInitialState,
};
export const { configChanged } = configSlice.actions;
export const selectConfigSlice = (state: RootState) => state.config;
const createConfigSelector = <T>(selector: Selector<ConfigState, T>) => createSelector(selectConfigSlice, selector);
const createConfigSelector = <T>(selector: Selector<typeof initialConfigState, T>) =>
createSelector(selectConfigSlice, selector);
export const selectWidthConfig = createConfigSelector((config) => config.sd.width);
export const selectHeightConfig = createConfigSelector((config) => config.sd.height);

View File

@@ -3,15 +3,12 @@ import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { LogNamespace } from 'app/logging/logger';
import { zLogNamespace } from 'app/logging/logger';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { uniq } from 'es-toolkit/compat';
import { assert } from 'tsafe';
import { type Language, type SystemState, zSystemState } from './types';
import type { Language, SystemState } from './types';
const getInitialState = (): SystemState => ({
const initialSystemState: SystemState = {
_version: 2,
shouldConfirmOnDelete: true,
shouldAntialiasProgressImage: false,
@@ -26,11 +23,11 @@ const getInitialState = (): SystemState => ({
logNamespaces: [...zLogNamespace.options],
shouldShowInvocationProgressDetail: false,
shouldHighlightFocusedRegions: false,
});
};
const slice = createSlice({
export const systemSlice = createSlice({
name: 'system',
initialState: getInitialState(),
initialState: initialSystemState,
reducers: {
setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => {
state.shouldConfirmOnDelete = action.payload;
@@ -92,25 +89,25 @@ export const {
shouldConfirmOnNewSessionToggled,
setShouldShowInvocationProgressDetail,
setShouldHighlightFocusedRegions,
} = slice.actions;
} = systemSlice.actions;
export const systemSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zSystemState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state.language = (state as SystemState).language.replace('_', '-');
state._version = 2;
}
return zSystemState.parse(state);
},
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateSystemState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state.language = (state as SystemState).language.replace('_', '-');
state._version = 2;
}
return state;
};
export const systemPersistConfig: PersistConfig<SystemState> = {
name: systemSlice.name,
initialState: initialSystemState,
migrate: migrateSystemState,
persistDenylist: [],
};
export const selectSystemSlice = (state: RootState) => state.system;

View File

@@ -1,4 +1,4 @@
import { zLogLevel, zLogNamespace } from 'app/logging/logger';
import type { LogLevel, LogNamespace } from 'app/logging/logger';
import { z } from 'zod';
const zLanguage = z.enum([
@@ -29,20 +29,19 @@ const zLanguage = z.enum([
export type Language = z.infer<typeof zLanguage>;
export const isLanguage = (v: unknown): v is Language => zLanguage.safeParse(v).success;
export const zSystemState = z.object({
_version: z.literal(2),
shouldConfirmOnDelete: z.boolean(),
shouldAntialiasProgressImage: z.boolean(),
shouldConfirmOnNewSession: z.boolean(),
language: zLanguage,
shouldUseNSFWChecker: z.boolean(),
shouldUseWatermarker: z.boolean(),
shouldEnableInformationalPopovers: z.boolean(),
shouldEnableModelDescriptions: z.boolean(),
logIsEnabled: z.boolean(),
logLevel: zLogLevel,
logNamespaces: z.array(zLogNamespace),
shouldShowInvocationProgressDetail: z.boolean(),
shouldHighlightFocusedRegions: z.boolean(),
});
export type SystemState = z.infer<typeof zSystemState>;
export interface SystemState {
_version: 2;
shouldConfirmOnDelete: boolean;
shouldAntialiasProgressImage: boolean;
shouldConfirmOnNewSession: boolean;
language: Language;
shouldUseNSFWChecker: boolean;
shouldUseWatermarker: boolean;
shouldEnableInformationalPopovers: boolean;
shouldEnableModelDescriptions: boolean;
logIsEnabled: boolean;
logLevel: LogLevel;
logNamespaces: LogNamespace[];
shouldShowInvocationProgressDetail: boolean;
shouldHighlightFocusedRegions: boolean;
}

View File

@@ -1,7 +1,6 @@
import { Box, Button, ButtonGroup, Flex, Grid, Heading, Icon, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import {
@@ -38,7 +37,7 @@ export const UpscalingLaunchpadPanel = memo(() => {
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
dispatch(upscaleInitialImageChanged(imageDTO));
},
[dispatch]
);

View File

@@ -1,5 +1,5 @@
import type { DockviewApi, GridviewApi } from 'dockview';
import { DockviewApi as MockedDockviewApi, DockviewPanel, GridviewPanel } from 'dockview';
import { DockviewPanel, GridviewPanel } from 'dockview';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import type { NavigationAppApi } from './navigation-api';
@@ -12,7 +12,6 @@ import {
RIGHT_PANEL_MIN_SIZE_PX,
SETTINGS_PANEL_ID,
SWITCH_TABS_FAKE_DELAY_MS,
VIEWER_PANEL_ID,
WORKSPACE_PANEL_ID,
} from './shared';
@@ -49,7 +48,7 @@ vi.mock('dockview', async () => {
}
}
// Mock DockviewPanel class for instanceof checks
// Mock GridviewPanel class for instanceof checks
class MockDockviewPanel {
api = {
setActive: vi.fn(),
@@ -59,21 +58,10 @@ vi.mock('dockview', async () => {
};
}
// Mock DockviewApi class for instanceof checks
class MockDockviewApi {
panels = [];
activePanel = null;
toJSON = vi.fn();
fromJSON = vi.fn();
onDidLayoutChange = vi.fn();
onDidActivePanelChange = vi.fn();
}
return {
...actual,
GridviewPanel: MockGridviewPanel,
DockviewPanel: MockDockviewPanel,
DockviewApi: MockDockviewApi,
};
});
@@ -1117,393 +1105,4 @@ describe('AppNavigationApi', () => {
expect(initialize).not.toHaveBeenCalled();
});
});
describe('toggleViewerPanel', () => {
beforeEach(() => {
navigationApi.connectToApp(mockAppApi);
});
it('should switch to viewer panel when not currently on viewer', async () => {
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current panel to something other than viewer
navigationApi._currentActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(true);
expect(mockViewerPanel.api.setActive).toHaveBeenCalledOnce();
});
it('should switch to previous panel when on viewer and previous panel exists', async () => {
const mockPreviousPanel = createMockDockPanel();
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', SETTINGS_PANEL_ID, mockPreviousPanel);
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current panel to viewer and previous to settings
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(true);
expect(mockPreviousPanel.api.setActive).toHaveBeenCalledOnce();
expect(mockViewerPanel.api.setActive).not.toHaveBeenCalled();
});
it('should switch to launchpad when on viewer and no valid previous panel', async () => {
const mockLaunchpadPanel = createMockDockPanel();
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', LAUNCHPAD_PANEL_ID, mockLaunchpadPanel);
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current panel to viewer and no previous panel
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', null);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(true);
expect(mockLaunchpadPanel.api.setActive).toHaveBeenCalledOnce();
expect(mockViewerPanel.api.setActive).not.toHaveBeenCalled();
});
it('should switch to launchpad when on viewer and previous panel is also viewer', async () => {
const mockLaunchpadPanel = createMockDockPanel();
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', LAUNCHPAD_PANEL_ID, mockLaunchpadPanel);
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current panel to viewer and previous panel was also viewer
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(true);
expect(mockLaunchpadPanel.api.setActive).toHaveBeenCalledOnce();
expect(mockViewerPanel.api.setActive).not.toHaveBeenCalled();
});
it('should return false when no active tab', async () => {
mockGetAppTab.mockReturnValue(null);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(false);
});
it('should return false when viewer panel is not registered', async () => {
mockGetAppTab.mockReturnValue('generate');
navigationApi._currentActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
// Don't register viewer panel
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(false);
});
it('should return false when previous panel is not registered', async () => {
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current to viewer and previous to unregistered panel
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', 'unregistered-panel');
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(false);
});
it('should return false when launchpad panel is not registered as fallback', async () => {
const mockViewerPanel = createMockDockPanel();
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
mockGetAppTab.mockReturnValue('generate');
// Set current to viewer and no previous panel, but don't register launchpad
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', null);
const result = await navigationApi.toggleViewerPanel();
expect(result).toBe(false);
});
it('should work across different tabs independently', async () => {
const mockViewerPanel1 = createMockDockPanel();
const mockViewerPanel2 = createMockDockPanel();
const mockSettingsPanel1 = createMockDockPanel();
const mockSettingsPanel2 = createMockDockPanel();
const mockLaunchpadPanel = createMockDockPanel();
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel1);
navigationApi._registerPanel('generate', SETTINGS_PANEL_ID, mockSettingsPanel1);
navigationApi._registerPanel('canvas', VIEWER_PANEL_ID, mockViewerPanel2);
navigationApi._registerPanel('canvas', SETTINGS_PANEL_ID, mockSettingsPanel2);
navigationApi._registerPanel('canvas', LAUNCHPAD_PANEL_ID, mockLaunchpadPanel);
// Set up different states for different tabs
navigationApi._currentActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
navigationApi._currentActiveDockviewPanel.set('canvas', VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('canvas', SETTINGS_PANEL_ID);
// Test generate tab (should switch to viewer)
mockGetAppTab.mockReturnValue('generate');
const result1 = await navigationApi.toggleViewerPanel();
expect(result1).toBe(true);
expect(mockViewerPanel1.api.setActive).toHaveBeenCalledOnce();
// Test canvas tab (should switch to previous panel - settings panel in canvas)
mockGetAppTab.mockReturnValue('canvas');
const result2 = await navigationApi.toggleViewerPanel();
expect(result2).toBe(true);
expect(mockSettingsPanel2.api.setActive).toHaveBeenCalledOnce();
});
it('should handle sequence of viewer toggles correctly', async () => {
const mockViewerPanel = createMockDockPanel();
const mockSettingsPanel = createMockDockPanel();
const mockLaunchpadPanel = createMockDockPanel();
navigationApi._registerPanel('generate', VIEWER_PANEL_ID, mockViewerPanel);
navigationApi._registerPanel('generate', SETTINGS_PANEL_ID, mockSettingsPanel);
navigationApi._registerPanel('generate', LAUNCHPAD_PANEL_ID, mockLaunchpadPanel);
mockGetAppTab.mockReturnValue('generate');
// Start on settings panel
navigationApi._currentActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set('generate', null);
// First toggle: settings -> viewer
const result1 = await navigationApi.toggleViewerPanel();
expect(result1).toBe(true);
expect(mockViewerPanel.api.setActive).toHaveBeenCalledOnce();
// Simulate panel change tracking (normally done by dockview listener)
navigationApi._prevActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
navigationApi._currentActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
// Second toggle: viewer -> settings (previous panel)
const result2 = await navigationApi.toggleViewerPanel();
expect(result2).toBe(true);
expect(mockSettingsPanel.api.setActive).toHaveBeenCalledOnce();
// Simulate panel change tracking again
navigationApi._prevActiveDockviewPanel.set('generate', VIEWER_PANEL_ID);
navigationApi._currentActiveDockviewPanel.set('generate', SETTINGS_PANEL_ID);
// Third toggle: settings -> viewer again
const result3 = await navigationApi.toggleViewerPanel();
expect(result3).toBe(true);
expect(mockViewerPanel.api.setActive).toHaveBeenCalledTimes(2);
});
});
describe('Disposable Cleanup', () => {
beforeEach(() => {
navigationApi.connectToApp(mockAppApi);
});
it('should add disposable functions for a tab', () => {
const dispose1 = vi.fn();
const dispose2 = vi.fn();
navigationApi._addDisposeForTab('generate', dispose1);
navigationApi._addDisposeForTab('generate', dispose2);
// Check that disposables are stored
const disposables = navigationApi._disposablesForTab.get('generate');
expect(disposables).toBeDefined();
expect(disposables?.size).toBe(2);
expect(disposables?.has(dispose1)).toBe(true);
expect(disposables?.has(dispose2)).toBe(true);
});
it('should handle multiple tabs independently', () => {
const dispose1 = vi.fn();
const dispose2 = vi.fn();
const dispose3 = vi.fn();
navigationApi._addDisposeForTab('generate', dispose1);
navigationApi._addDisposeForTab('generate', dispose2);
navigationApi._addDisposeForTab('canvas', dispose3);
const generateDisposables = navigationApi._disposablesForTab.get('generate');
const canvasDisposables = navigationApi._disposablesForTab.get('canvas');
expect(generateDisposables?.size).toBe(2);
expect(canvasDisposables?.size).toBe(1);
expect(generateDisposables?.has(dispose1)).toBe(true);
expect(generateDisposables?.has(dispose2)).toBe(true);
expect(canvasDisposables?.has(dispose3)).toBe(true);
});
it('should call all dispose functions when unregistering a tab', () => {
const dispose1 = vi.fn();
const dispose2 = vi.fn();
const dispose3 = vi.fn();
// Add disposables for generate tab
navigationApi._addDisposeForTab('generate', dispose1);
navigationApi._addDisposeForTab('generate', dispose2);
// Add disposable for canvas tab (should not be called)
navigationApi._addDisposeForTab('canvas', dispose3);
// Unregister generate tab
navigationApi.unregisterTab('generate');
// Check that generate tab disposables were called
expect(dispose1).toHaveBeenCalledOnce();
expect(dispose2).toHaveBeenCalledOnce();
// Check that canvas tab disposable was not called
expect(dispose3).not.toHaveBeenCalled();
// Check that generate tab disposables are cleared
expect(navigationApi._disposablesForTab.has('generate')).toBe(false);
// Check that canvas tab disposables remain
expect(navigationApi._disposablesForTab.has('canvas')).toBe(true);
});
it('should handle unregistering tab with no disposables gracefully', () => {
// Should not throw when unregistering tab with no disposables
expect(() => navigationApi.unregisterTab('generate')).not.toThrow();
});
it('should handle duplicate dispose functions', () => {
const dispose1 = vi.fn();
// Add the same dispose function twice
navigationApi._addDisposeForTab('generate', dispose1);
navigationApi._addDisposeForTab('generate', dispose1);
const disposables = navigationApi._disposablesForTab.get('generate');
// Set should contain only one instance (sets don't allow duplicates)
expect(disposables?.size).toBe(1);
navigationApi.unregisterTab('generate');
// Should be called only once despite being added twice
expect(dispose1).toHaveBeenCalledOnce();
});
it('should automatically add dispose functions during container registration with DockviewApi', () => {
const tab = 'generate';
const viewId = 'myView';
mockGetStorage.mockReturnValue(undefined);
const initialize = vi.fn();
const panel = { id: 'p1' };
const mockDispose = vi.fn();
// Create a mock that will pass the instanceof DockviewApi check
const mockApi = Object.create(MockedDockviewApi.prototype);
Object.assign(mockApi, {
panels: [panel],
activePanel: { id: 'p1' },
toJSON: vi.fn(() => ({ foo: 'bar' })),
onDidLayoutChange: vi.fn(() => ({ dispose: vi.fn() })),
onDidActivePanelChange: vi.fn(() => ({ dispose: mockDispose })),
});
navigationApi.registerContainer(tab, viewId, mockApi, initialize);
// Check that dispose function was added to disposables
const disposables = navigationApi._disposablesForTab.get(tab);
expect(disposables).toBeDefined();
expect(disposables?.size).toBe(1);
// Unregister tab and check dispose was called
navigationApi.unregisterTab(tab);
expect(mockDispose).toHaveBeenCalledOnce();
});
it('should not add dispose functions for GridviewApi during container registration', () => {
const tab = 'generate';
const viewId = 'myView';
mockGetStorage.mockReturnValue(undefined);
const initialize = vi.fn();
const panel = { id: 'p1' };
// Mock GridviewApi (not DockviewApi)
const mockApi = {
panels: [panel],
toJSON: vi.fn(() => ({ foo: 'bar' })),
onDidLayoutChange: vi.fn(() => ({ dispose: vi.fn() })),
} as unknown as GridviewApi;
navigationApi.registerContainer(tab, viewId, mockApi, initialize);
// Check that no dispose function was added for GridviewApi
const disposables = navigationApi._disposablesForTab.get(tab);
expect(disposables).toBeUndefined();
});
it('should handle dispose function errors gracefully', () => {
const goodDispose = vi.fn();
const errorDispose = vi.fn(() => {
throw new Error('Dispose error');
});
const anotherGoodDispose = vi.fn();
navigationApi._addDisposeForTab('generate', goodDispose);
navigationApi._addDisposeForTab('generate', errorDispose);
navigationApi._addDisposeForTab('generate', anotherGoodDispose);
// Should not throw even if one dispose function throws
expect(() => navigationApi.unregisterTab('generate')).not.toThrow();
// All dispose functions should have been called
expect(goodDispose).toHaveBeenCalledOnce();
expect(errorDispose).toHaveBeenCalledOnce();
expect(anotherGoodDispose).toHaveBeenCalledOnce();
});
it('should clear panel tracking state when unregistering tab', () => {
const tab = 'generate';
// Set up some panel tracking state
navigationApi._currentActiveDockviewPanel.set(tab, VIEWER_PANEL_ID);
navigationApi._prevActiveDockviewPanel.set(tab, SETTINGS_PANEL_ID);
// Add some disposables
const dispose1 = vi.fn();
const dispose2 = vi.fn();
navigationApi._addDisposeForTab(tab, dispose1);
navigationApi._addDisposeForTab(tab, dispose2);
// Verify state exists before unregistering
expect(navigationApi._currentActiveDockviewPanel.has(tab)).toBe(true);
expect(navigationApi._prevActiveDockviewPanel.has(tab)).toBe(true);
expect(navigationApi._disposablesForTab.has(tab)).toBe(true);
// Unregister tab
navigationApi.unregisterTab(tab);
// Verify all state is cleared
expect(navigationApi._currentActiveDockviewPanel.has(tab)).toBe(false);
expect(navigationApi._prevActiveDockviewPanel.has(tab)).toBe(false);
expect(navigationApi._disposablesForTab.has(tab)).toBe(false);
// Verify dispose functions were called
expect(dispose1).toHaveBeenCalledOnce();
expect(dispose2).toHaveBeenCalledOnce();
});
});
});

View File

@@ -1,21 +1,19 @@
import { logger } from 'app/logging/logger';
import { createDeferredPromise, type Deferred } from 'common/util/createDeferredPromise';
import { parseify } from 'common/util/serialize';
import type { GridviewApi, IDockviewPanel, IGridviewPanel } from 'dockview';
import { DockviewApi, GridviewPanel } from 'dockview';
import type { DockviewApi, GridviewApi, IDockviewPanel, IGridviewPanel } from 'dockview';
import { GridviewPanel } from 'dockview';
import { debounce } from 'es-toolkit';
import type { Serializable, TabName } from 'features/ui/store/uiTypes';
import type { Atom } from 'nanostores';
import { atom } from 'nanostores';
import {
LAUNCHPAD_PANEL_ID,
LEFT_PANEL_ID,
LEFT_PANEL_MIN_SIZE_PX,
RIGHT_PANEL_ID,
RIGHT_PANEL_MIN_SIZE_PX,
SWITCH_TABS_FAKE_DELAY_MS,
VIEWER_PANEL_ID,
} from './shared';
const log = logger('system');
@@ -71,37 +69,6 @@ export class NavigationApi {
private _$isLoading = atom(false);
$isLoading: Atom<boolean> = this._$isLoading;
/**
* Track the _previous_ active dockview panel for each tab.
*/
_prevActiveDockviewPanel: Map<TabName, string | null> = new Map();
/**
* Track the _current_ active dockview panel for each tab.
*/
_currentActiveDockviewPanel: Map<TabName, string | null> = new Map();
/**
* Map of disposables for each tab.
* This is used to clean up resources when a tab is unregistered.
*/
_disposablesForTab: Map<TabName, Set<() => void>> = new Map();
/**
* Convenience method to add a dispose function for a specific tab.
*/
/**
* Convenience method to add a dispose function for a specific tab.
*/
_addDisposeForTab = (tab: TabName, disposeFn: () => void): void => {
let disposables = this._disposablesForTab.get(tab);
if (!disposables) {
disposables = new Set<() => void>();
this._disposablesForTab.set(tab, disposables);
}
disposables.add(disposeFn);
};
/**
* Separator used to create unique keys for panels. Typo protection.
*/
@@ -242,18 +209,6 @@ export class NavigationApi {
this._registerPanel(tab, panel.id, panel);
}
// Set up tracking for active tab for this panel - needed for viewer toggle functionality
if (api instanceof DockviewApi) {
this._currentActiveDockviewPanel.set(tab, api.activePanel?.id ?? null);
this._prevActiveDockviewPanel.set(tab, null);
const { dispose } = api.onDidActivePanelChange((panel) => {
const previousPanelId = this._currentActiveDockviewPanel.get(tab);
this._prevActiveDockviewPanel.set(tab, previousPanelId ?? null);
this._currentActiveDockviewPanel.set(tab, panel?.id ?? null);
});
this._addDisposeForTab(tab, dispose);
}
api.onDidLayoutChange(
debounce(() => {
this._app?.storage.set(key, api.toJSON());
@@ -590,42 +545,6 @@ export class NavigationApi {
return true;
};
/**
* Toggle between the viewer panel and the previously focused dockview panel in the current tab.
* If currently on viewer and a previous panel exists, switch to the previous panel.
* If not on viewer, switch to viewer.
* If no previous panel exists, defaults to launchpad panel.
* Only operates on dockview panels (panels with tabs), not gridview panels.
*
* @returns Promise that resolves to true if successful, false otherwise
*/
toggleViewerPanel = (): Promise<boolean> => {
const activeTab = this._app?.activeTab.get() ?? null;
if (!activeTab) {
log.warn('No active tab found for viewer toggle');
return Promise.resolve(false);
}
const prevActiveDockviewPanel = this._prevActiveDockviewPanel.get(activeTab);
const currentActiveDockviewPanel = this._currentActiveDockviewPanel.get(activeTab);
let targetPanel;
if (currentActiveDockviewPanel !== VIEWER_PANEL_ID) {
targetPanel = VIEWER_PANEL_ID;
} else if (prevActiveDockviewPanel && prevActiveDockviewPanel !== VIEWER_PANEL_ID) {
targetPanel = prevActiveDockviewPanel;
} else {
targetPanel = LAUNCHPAD_PANEL_ID;
}
if (this.getRegisteredPanels(activeTab).includes(targetPanel)) {
return this.focusPanel(activeTab, targetPanel);
}
return Promise.resolve(false);
};
/**
* Check if a panel is registered.
* @param tab - The tab the panel belongs to
@@ -674,18 +593,6 @@ export class NavigationApi {
this.waiters.delete(key);
}
// Clear previous panel tracking for this tab
this._prevActiveDockviewPanel.delete(tab);
this._currentActiveDockviewPanel.delete(tab);
this._disposablesForTab.get(tab)?.forEach((disposeFn) => {
try {
disposeFn();
} catch (error) {
log.error({ error: parseify(error) }, `Error disposing resource for tab ${tab}`);
}
});
this._disposablesForTab.delete(tab);
log.trace(`Unregistered all panels for tab ${tab}`);
};
}

View File

@@ -1,13 +1,11 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { assert } from 'tsafe';
import type { PersistConfig, RootState } from 'app/store/store';
import { getInitialUIState, type UIState, zUIState } from './uiTypes';
import type { UIState } from './uiTypes';
import { getInitialUIState } from './uiTypes';
const slice = createSlice({
export const uiSlice = createSlice({
name: 'ui',
initialState: getInitialUIState(),
reducers: {
@@ -83,30 +81,29 @@ export const {
textAreaSizesStateChanged,
dockviewStorageKeyChanged,
pickerCompactViewStateChanged,
} = slice.actions;
} = uiSlice.actions;
export const selectUiSlice = (state: RootState) => state.ui;
export const uiSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zUIState,
getInitialState: getInitialUIState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state.activeTab = 'generation';
state._version = 2;
}
if (state._version === 2) {
state.activeTab = 'canvas';
state._version = 3;
}
return zUIState.parse(state);
},
persistDenylist: ['shouldShowImageDetails'],
},
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateUIState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state.activeTab = 'generation';
state._version = 2;
}
if (state._version === 2) {
state.activeTab = 'canvas';
state._version = 3;
}
return state;
};
export const uiPersistConfig: PersistConfig<UIState> = {
name: uiSlice.name,
initialState: getInitialUIState(),
migrate: migrateUIState,
persistDenylist: ['shouldShowImageDetails'],
};

View File

@@ -1,7 +1,8 @@
import { deepClone } from 'common/util/deepClone';
import { isPlainObject } from 'es-toolkit';
import { z } from 'zod';
export const zTabName = z.enum(['generate', 'canvas', 'upscaling', 'workflows', 'models', 'queue']);
const zTabName = z.enum(['generate', 'canvas', 'upscaling', 'workflows', 'models', 'queue']);
export type TabName = z.infer<typeof zTabName>;
const zPartialDimensions = z.object({
@@ -12,28 +13,18 @@ const zPartialDimensions = z.object({
const zSerializable = z.any().refine(isPlainObject);
export type Serializable = z.infer<typeof zSerializable>;
export const zUIState = z.object({
_version: z.literal(3),
activeTab: zTabName,
shouldShowImageDetails: z.boolean(),
shouldShowProgressInViewer: z.boolean(),
accordions: z.record(z.string(), z.boolean()),
expanders: z.record(z.string(), z.boolean()),
textAreaSizes: z.record(z.string(), zPartialDimensions),
panels: z.record(z.string(), zSerializable),
shouldShowNotificationV2: z.boolean(),
pickerCompactViewStates: z.record(z.string(), z.boolean()),
const zUIState = z.object({
_version: z.literal(3).default(3),
activeTab: zTabName.default('generate'),
shouldShowImageDetails: z.boolean().default(false),
shouldShowProgressInViewer: z.boolean().default(true),
accordions: z.record(z.string(), z.boolean()).default(() => ({})),
expanders: z.record(z.string(), z.boolean()).default(() => ({})),
textAreaSizes: z.record(z.string(), zPartialDimensions).default({}),
panels: z.record(z.string(), zSerializable).default({}),
shouldShowNotificationV2: z.boolean().default(true),
pickerCompactViewStates: z.record(z.string(), z.boolean()).default(() => ({})),
});
const INITIAL_STATE = zUIState.parse({});
export type UIState = z.infer<typeof zUIState>;
export const getInitialUIState = (): UIState => ({
_version: 3 as const,
activeTab: 'generate' as const,
shouldShowImageDetails: false,
shouldShowProgressInViewer: true,
accordions: {},
expanders: {},
textAreaSizes: {},
panels: {},
shouldShowNotificationV2: true,
pickerCompactViewStates: {},
});
export const getInitialUIState = (): UIState => deepClone(INITIAL_STATE);

View File

@@ -1,6 +1,5 @@
import { $openAPISchemaUrl } from 'app/store/nanostores/openAPISchemaUrl';
import type { OpenAPIV3_1 } from 'openapi-types';
import type { stringify } from 'querystring';
import type { paths } from 'services/api/schema';
import type { AppConfig, AppVersion } from 'services/api/types';
@@ -12,8 +11,7 @@ import { api, buildV1Url } from '..';
* buildAppInfoUrl('some-path')
* // '/api/v1/app/some-path'
*/
export const buildAppInfoUrl = (path: string = '', query?: Parameters<typeof stringify>[0]) =>
buildV1Url(`app/${path}`, query);
const buildAppInfoUrl = (path: string = '') => buildV1Url(`app/${path}`);
export const appInfoApi = api.injectEndpoints({
endpoints: (build) => ({
@@ -89,31 +87,6 @@ export const appInfoApi = api.injectEndpoints({
},
providesTags: ['Schema'],
}),
getClientStateByKey: build.query<
paths['/api/v1/app/client_state']['get']['responses']['200']['content']['application/json'],
paths['/api/v1/app/client_state']['get']['parameters']['query']
>({
query: () => ({
url: buildAppInfoUrl('client_state'),
method: 'GET',
}),
}),
setClientStateByKey: build.mutation<
paths['/api/v1/app/client_state']['post']['responses']['200']['content']['application/json'],
paths['/api/v1/app/client_state']['post']['requestBody']['content']['application/json']
>({
query: (body) => ({
url: buildAppInfoUrl('client_state'),
method: 'POST',
body,
}),
}),
deleteClientState: build.mutation<void, void>({
query: () => ({
url: buildAppInfoUrl('client_state'),
method: 'DELETE',
}),
}),
}),
});

View File

@@ -57,18 +57,13 @@ const tagTypes = [
// This is invalidated on reconnect. It should be used for queries that have changing data,
// especially related to the queue and generation.
'FetchOnReconnect',
'ClientState',
] as const;
export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
export const LIST_TAG = 'LIST';
export const LIST_ALL_TAG = 'LIST_ALL';
export const getBaseUrl = (): string => {
const baseUrl = $baseUrl.get();
return baseUrl || window.location.href.replace(/\/$/, '');
};
const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryError> = (args, api, extraOptions) => {
const baseUrl = $baseUrl.get();
const authToken = $authToken.get();
const projectId = $projectId.get();
const isOpenAPIRequest =
@@ -76,7 +71,7 @@ const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryE
(typeof args === 'string' && args.includes('openapi.json'));
const fetchBaseQueryArgs: FetchBaseQueryArgs = {
baseUrl: getBaseUrl(),
baseUrl: baseUrl || window.location.href.replace(/\/$/, ''),
};
// When fetching the openapi.json, we need to remove circular references from the JSON.

View File

@@ -1164,34 +1164,6 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/app/client_state": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get Client State By Key
* @description Gets the client state
*/
get: operations["get_client_state_by_key"];
put?: never;
/**
* Set Client State
* @description Sets the client state
*/
post: operations["set_client_state"];
/**
* Delete Client State
* @description Deletes the client state
*/
delete: operations["delete_client_state"];
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/queue/{queue_id}/enqueue_batch": {
parameters: {
query?: never;
@@ -24725,101 +24697,6 @@ export interface operations {
};
};
};
get_client_state_by_key: {
parameters: {
query: {
/** @description Key to get */
key: string;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["JsonValue"] | null;
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
set_client_state: {
parameters: {
query: {
/** @description Key to set */
key: string;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["JsonValue"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": unknown;
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
delete_client_state: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": unknown;
};
};
/** @description Client state deleted */
204: {
headers: {
[name: string]: unknown;
};
content?: never;
};
};
};
enqueue_batch: {
parameters: {
query?: never;

Some files were not shown because too many files have changed in this diff Show More