Compare commits

...

27 Commits

Author SHA1 Message Date
psychedelicious
54bda8e8e4 chore: bump version to v6.2.0a1 2025-07-25 17:16:51 +10:00
psychedelicious
c14889f055 tidy(ui): enable devmode redux checks 2025-07-25 17:15:19 +10:00
psychedelicious
86680f296a chore(ui): lint 2025-07-25 17:15:19 +10:00
psychedelicious
de0b7801a6 fix(ui): infinite loop when setting tile controlnet model 2025-07-25 17:15:19 +10:00
psychedelicious
a03c7ca4e3 fix(ui): do not store whole model configs in state 2025-07-25 17:15:19 +10:00
psychedelicious
32af53779d refactor(ui): just manually validate async stuff 2025-07-25 17:15:19 +10:00
psychedelicious
a8662953fc refactor(ui): work around zod async validation issue 2025-07-25 17:15:19 +10:00
psychedelicious
82cdfd83e4 fix(ui): check initial retrieval and set as last persisted 2025-07-25 17:15:19 +10:00
psychedelicious
3f3fdf0b43 chore(ui): bump zod to latest
Checking if it fixes an issue w/ async validators
2025-07-25 17:15:18 +10:00
psychedelicious
53dbd5a7c9 refactor(ui): use zod for all redux state 2025-07-25 17:15:18 +10:00
psychedelicious
bbe5979349 refactor(ui): use zod for all redux state (wip)
needed for confidence w/ state rehydration logic
2025-07-25 17:15:18 +10:00
psychedelicious
ca70540ddd feat(ui): iterate on storage api 2025-07-25 17:15:18 +10:00
psychedelicious
37e25ccbf7 refactor(ui): restructure persistence driver creation to support custom drivers 2025-07-25 17:15:18 +10:00
psychedelicious
28e7a83f98 revert(ui): temp changes to main.tsx for testing 2025-07-25 17:15:18 +10:00
psychedelicious
3b39912b1c revert(ui): temp disable eslint rule 2025-07-25 17:15:18 +10:00
psychedelicious
c76698f205 git: update gitignore 2025-07-25 17:15:18 +10:00
psychedelicious
8f27a393d8 wip 2025-07-25 17:15:18 +10:00
psychedelicious
84ff6dbe69 chore: ruff 2025-07-25 17:15:18 +10:00
psychedelicious
4620a2137c tests(app): service mocks 2025-07-25 17:15:18 +10:00
psychedelicious
8ddbd979dd chore(ui): lint 2025-07-25 17:15:17 +10:00
psychedelicious
19ec9d268e refactor(ui): iterate on persistence 2025-07-25 17:15:17 +10:00
psychedelicious
ab683802ba refactor(ui): iterate on persistence 2025-07-25 17:15:17 +10:00
psychedelicious
98957ec9ea refactor(ui): alternate approach to slice configs 2025-07-25 17:15:17 +10:00
psychedelicious
7936ee9b7f chore(ui): typegen 2025-07-25 17:15:17 +10:00
psychedelicious
a96b7afdfb feat(api): make client state key query not body 2025-07-25 17:15:17 +10:00
psychedelicious
bb58a70b70 refactor(ui): cleaner slice definitions 2025-07-25 17:15:17 +10:00
psychedelicious
aaa1e1a480 feat: server-side client state persistence 2025-07-25 17:15:17 +10:00
89 changed files with 2276 additions and 1305 deletions

View File

@@ -10,6 +10,7 @@ from invokeai.app.services.board_images.board_images_default import BoardImagesS
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.boards.boards_default import BoardService from invokeai.app.services.boards.boards_default import BoardService
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.download.download_default import DownloadQueueService from invokeai.app.services.download.download_default import DownloadQueueService
from invokeai.app.services.events.events_fastapievents import FastAPIEventService from invokeai.app.services.events.events_fastapievents import FastAPIEventService
@@ -151,6 +152,7 @@ class ApiDependencies:
style_preset_records = SqliteStylePresetRecordsStorage(db=db) style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images") style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder) workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
client_state_persistence = ClientStatePersistenceSqlite(db=db)
services = InvocationServices( services = InvocationServices(
board_image_records=board_image_records, board_image_records=board_image_records,
@@ -181,6 +183,7 @@ class ApiDependencies:
style_preset_records=style_preset_records, style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files, style_preset_image_files=style_preset_image_files,
workflow_thumbnails=workflow_thumbnails, workflow_thumbnails=workflow_thumbnails,
client_state_persistence=client_state_persistence,
) )
ApiDependencies.invoker = Invoker(services) ApiDependencies.invoker = Invoker(services)

View File

@@ -5,9 +5,9 @@ from pathlib import Path
from typing import Optional from typing import Optional
import torch import torch
from fastapi import Body from fastapi import Body, HTTPException, Query
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, JsonValue
from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.upscale import ESRGAN_MODELS from invokeai.app.invocations.upscale import ESRGAN_MODELS
@@ -173,3 +173,50 @@ async def disable_invocation_cache() -> None:
async def get_invocation_cache_status() -> InvocationCacheStatus: async def get_invocation_cache_status() -> InvocationCacheStatus:
"""Clears the invocation cache""" """Clears the invocation cache"""
return ApiDependencies.invoker.services.invocation_cache.get_status() return ApiDependencies.invoker.services.invocation_cache.get_status()
@app_router.get(
"/client_state",
operation_id="get_client_state_by_key",
response_model=JsonValue | None,
)
async def get_client_state_by_key(
key: str = Query(..., description="Key to get"),
) -> JsonValue | None:
"""Gets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key)
except Exception as e:
logging.error(f"Error getting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@app_router.post(
"/client_state",
operation_id="set_client_state",
response_model=None,
)
async def set_client_state(
key: str = Query(..., description="Key to set"),
value: JsonValue = Body(..., description="Value of the key"),
) -> None:
"""Sets the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@app_router.delete(
"/client_state",
operation_id="delete_client_state",
responses={204: {"description": "Client state deleted"}},
)
async def delete_client_state() -> None:
"""Deletes the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.delete()
except Exception as e:
logging.error(f"Error deleting client state: {e}")
raise HTTPException(status_code=500, detail="Error deleting client state")

View File

@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from pydantic import JsonValue
class ClientStatePersistenceABC(ABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
@abstractmethod
def set_by_key(self, key: str, value: JsonValue) -> None:
"""
Store the data for the client.
:param data: The client data to be stored.
"""
pass
@abstractmethod
def get_by_key(self, key: str) -> JsonValue | None:
"""
Get the data for the client.
:return: The client data.
"""
pass
@abstractmethod
def delete(self) -> None:
"""
Delete the data for the client.
"""
pass

View File

@@ -0,0 +1,65 @@
import json
from pydantic import JsonValue
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._default_row_id = 1
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def set_by_key(self, key: str, value: JsonValue) -> None:
state = self.get() or {}
state.update({key: value})
with self._db.transaction() as cursor:
cursor.execute(
f"""
INSERT INTO client_state (id, data)
VALUES ({self._default_row_id}, ?)
ON CONFLICT(id) DO UPDATE
SET data = excluded.data;
""",
(json.dumps(state),),
)
def get(self) -> dict[str, JsonValue] | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])
def get_by_key(self, key: str) -> JsonValue | None:
state = self.get()
if state is None:
return None
return state.get(key, None)
def delete(self) -> None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
DELETE FROM client_state
WHERE id = {self._default_row_id}
"""
)

View File

@@ -17,6 +17,7 @@ if TYPE_CHECKING:
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
from invokeai.app.services.boards.boards_base import BoardServiceABC from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
@@ -73,6 +74,7 @@ class InvocationServices:
style_preset_records: "StylePresetRecordsStorageBase", style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase", style_preset_image_files: "StylePresetImageFileStorageBase",
workflow_thumbnails: "WorkflowThumbnailServiceBase", workflow_thumbnails: "WorkflowThumbnailServiceBase",
client_state_persistence: "ClientStatePersistenceABC",
): ):
self.board_images = board_images self.board_images = board_images
self.board_image_records = board_image_records self.board_image_records = board_image_records
@@ -102,3 +104,4 @@ class InvocationServices:
self.style_preset_records = style_preset_records self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files self.style_preset_image_files = style_preset_image_files
self.workflow_thumbnails = workflow_thumbnails self.workflow_thumbnails = workflow_thumbnails
self.client_state_persistence = client_state_persistence

View File

@@ -23,6 +23,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -63,6 +64,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_18()) migrator.register_migration(build_migration_18())
migrator.register_migration(build_migration_19(app_config=config)) migrator.register_migration(build_migration_19(app_config=config))
migrator.register_migration(build_migration_20()) migrator.register_migration(build_migration_20())
migrator.register_migration(build_migration_21())
migrator.run_migrations() migrator.run_migrations()
return db return db

View File

@@ -0,0 +1,40 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration21Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""
CREATE TABLE client_state (
id INTEGER PRIMARY KEY CHECK(id = 1),
data TEXT NOT NULL, -- Frontend will handle the shape of this data
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP)
);
"""
)
cursor.execute(
"""
CREATE TRIGGER tg_client_state_updated_at
AFTER UPDATE ON client_state
FOR EACH ROW
BEGIN
UPDATE client_state
SET updated_at = CURRENT_TIMESTAMP
WHERE id = OLD.id;
END;
"""
)
def build_migration_21() -> Migration:
"""Builds the migration object for migrating from version 20 to version 21. This includes:
- Creating the `client_state` table.
- Adding a trigger to update the `updated_at` field on updates.
"""
return Migration(
from_version=20,
to_version=21,
callback=Migration21Callback(),
)

View File

@@ -44,4 +44,5 @@ yalc.lock
# vitest # vitest
tsconfig.vitest-temp.json tsconfig.vitest-temp.json
coverage/ coverage/
*.tgz

View File

@@ -26,7 +26,7 @@ i18n.use(initReactI18next).init({
returnNull: false, returnNull: false,
}); });
const store = createStore(undefined, false); const store = createStore({ driver: { getItem: () => {}, setItem: () => {} }, persistThrottle: 2000 });
$store.set(store); $store.set(store);
$baseUrl.set('http://localhost:9090'); $baseUrl.set('http://localhost:9090');

View File

@@ -197,6 +197,10 @@ export default [
importNames: ['isEqual'], importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.', message: 'Please use objectEquals from @observ33r/object-equals instead.',
}, },
{
name: 'zod/v3',
message: 'Import from zod instead.',
},
], ],
}, },
], ],

View File

@@ -63,7 +63,6 @@
"framer-motion": "^11.10.0", "framer-motion": "^11.10.0",
"i18next": "^25.3.2", "i18next": "^25.3.2",
"i18next-http-backend": "^3.0.2", "i18next-http-backend": "^3.0.2",
"idb-keyval": "6.2.2",
"jsondiffpatch": "^0.7.3", "jsondiffpatch": "^0.7.3",
"konva": "^9.3.22", "konva": "^9.3.22",
"linkify-react": "^4.3.1", "linkify-react": "^4.3.1",
@@ -103,7 +102,7 @@
"use-debounce": "^10.0.5", "use-debounce": "^10.0.5",
"use-device-pixel-ratio": "^1.1.2", "use-device-pixel-ratio": "^1.1.2",
"uuid": "^11.1.0", "uuid": "^11.1.0",
"zod": "^4.0.5", "zod": "^4.0.10",
"zod-validation-error": "^3.5.2" "zod-validation-error": "^3.5.2"
}, },
"peerDependencies": { "peerDependencies": {

View File

@@ -80,9 +80,6 @@ importers:
i18next-http-backend: i18next-http-backend:
specifier: ^3.0.2 specifier: ^3.0.2
version: 3.0.2 version: 3.0.2
idb-keyval:
specifier: 6.2.2
version: 6.2.2
jsondiffpatch: jsondiffpatch:
specifier: ^0.7.3 specifier: ^0.7.3
version: 0.7.3 version: 0.7.3
@@ -201,11 +198,11 @@ importers:
specifier: ^11.1.0 specifier: ^11.1.0
version: 11.1.0 version: 11.1.0
zod: zod:
specifier: ^4.0.5 specifier: ^4.0.10
version: 4.0.5 version: 4.0.10
zod-validation-error: zod-validation-error:
specifier: ^3.5.2 specifier: ^3.5.2
version: 3.5.3(zod@4.0.5) version: 3.5.3(zod@4.0.10)
devDependencies: devDependencies:
'@eslint/js': '@eslint/js':
specifier: ^9.31.0 specifier: ^9.31.0
@@ -411,6 +408,10 @@ packages:
resolution: {integrity: sha512-vbavdySgbTTrmFE+EsiqUTzlOr5bzlnJtUv9PynGCAKvfQqjIXbvFdumPM/GxMDfyuGMJaJAU6TO4zc1Jf1i8Q==} resolution: {integrity: sha512-vbavdySgbTTrmFE+EsiqUTzlOr5bzlnJtUv9PynGCAKvfQqjIXbvFdumPM/GxMDfyuGMJaJAU6TO4zc1Jf1i8Q==}
engines: {node: '>=6.9.0'} engines: {node: '>=6.9.0'}
'@babel/runtime@7.28.2':
resolution: {integrity: sha512-KHp2IflsnGywDjBWDkR9iEqiWSpc8GIi0lgTT3mOElT0PP1tG26P4tmFI2YvAdzgq9RGyoHZQEIEdZy6Ec5xCA==}
engines: {node: '>=6.9.0'}
'@babel/template@7.27.2': '@babel/template@7.27.2':
resolution: {integrity: sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==} resolution: {integrity: sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==}
engines: {node: '>=6.9.0'} engines: {node: '>=6.9.0'}
@@ -2771,9 +2772,6 @@ packages:
typescript: typescript:
optional: true optional: true
idb-keyval@6.2.2:
resolution: {integrity: sha512-yjD9nARJ/jb1g+CvD0tlhUHOrJ9Sy0P8T9MF3YaLlHnSRpwPfpTX0XIvpmw3gAJUmEu3FiICLBDPXVwyEvrleg==}
ieee754@1.2.1: ieee754@1.2.1:
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==} resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
@@ -4511,8 +4509,8 @@ packages:
zod@3.25.76: zod@3.25.76:
resolution: {integrity: sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==} resolution: {integrity: sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==}
zod@4.0.5: zod@4.0.10:
resolution: {integrity: sha512-/5UuuRPStvHXu7RS+gmvRf4NXrNxpSllGwDnCBcJZtQsKrviYXm54yDGV2KYNLT5kq0lHGcl7lqWJLgSaG+tgA==} resolution: {integrity: sha512-3vB+UU3/VmLL2lvwcY/4RV2i9z/YU0DTV/tDuYjrwmx5WeJ7hwy+rGEEx8glHp6Yxw7ibRbKSaIFBgReRPe5KA==}
zustand@4.5.7: zustand@4.5.7:
resolution: {integrity: sha512-CHOUy7mu3lbD6o6LJLfllpjkzhHXSBlX8B9+qPddUsIfeF5S/UZ5q0kmCsnRqT1UHFQZchNFDDzMbQsuesHWlw==} resolution: {integrity: sha512-CHOUy7mu3lbD6o6LJLfllpjkzhHXSBlX8B9+qPddUsIfeF5S/UZ5q0kmCsnRqT1UHFQZchNFDDzMbQsuesHWlw==}
@@ -4633,6 +4631,8 @@ snapshots:
'@babel/runtime@7.27.6': {} '@babel/runtime@7.27.6': {}
'@babel/runtime@7.28.2': {}
'@babel/template@7.27.2': '@babel/template@7.27.2':
dependencies: dependencies:
'@babel/code-frame': 7.27.1 '@babel/code-frame': 7.27.1
@@ -5736,7 +5736,7 @@ snapshots:
'@testing-library/dom@10.4.0': '@testing-library/dom@10.4.0':
dependencies: dependencies:
'@babel/code-frame': 7.27.1 '@babel/code-frame': 7.27.1
'@babel/runtime': 7.27.6 '@babel/runtime': 7.28.2
'@types/aria-query': 5.0.4 '@types/aria-query': 5.0.4
aria-query: 5.3.0 aria-query: 5.3.0
chalk: 4.1.2 chalk: 4.1.2
@@ -7266,8 +7266,6 @@ snapshots:
optionalDependencies: optionalDependencies:
typescript: 5.8.3 typescript: 5.8.3
idb-keyval@6.2.2: {}
ieee754@1.2.1: {} ieee754@1.2.1: {}
ignore@5.3.2: {} ignore@5.3.2: {}
@@ -9062,13 +9060,13 @@ snapshots:
dependencies: dependencies:
zod: 3.25.76 zod: 3.25.76
zod-validation-error@3.5.3(zod@4.0.5): zod-validation-error@3.5.3(zod@4.0.10):
dependencies: dependencies:
zod: 4.0.5 zod: 4.0.10
zod@3.25.76: {} zod@3.25.76: {}
zod@4.0.5: {} zod@4.0.10: {}
zustand@4.5.7(@types/react@18.3.23)(immer@10.1.1)(react@18.3.1): zustand@4.5.7(@types/react@18.3.23)(immer@10.1.1)(react@18.3.1):
dependencies: dependencies:

View File

@@ -2,10 +2,10 @@ import { Box } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator'; import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator'; import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import { useClearStorage } from 'app/contexts/clear-storage-context';
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction'; import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
import type { PartialAppConfig } from 'app/types/invokeai'; import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading'; import Loading from 'common/components/Loading/Loading';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { AppContent } from 'features/ui/components/AppContent'; import { AppContent } from 'features/ui/components/AppContent';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { ErrorBoundary } from 'react-error-boundary'; import { ErrorBoundary } from 'react-error-boundary';

View File

@@ -1,10 +1,12 @@
import 'i18n'; import 'i18n';
import type { Middleware } from '@reduxjs/toolkit'; import type { Middleware } from '@reduxjs/toolkit';
import { ClearStorageProvider } from 'app/contexts/clear-storage-context';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction'; import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction'; import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger'; import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger'; import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { buildStorageApi } from 'app/store/enhancers/reduxRemember/driver';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink'; import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { $authToken } from 'app/store/nanostores/authToken'; import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $baseUrl } from 'app/store/nanostores/baseUrl';
@@ -70,6 +72,14 @@ interface Props extends PropsWithChildren {
* If provided, overrides in-app navigation to the model manager * If provided, overrides in-app navigation to the model manager
*/ */
onClickGoToModelManager?: () => void; onClickGoToModelManager?: () => void;
storageConfig?: {
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
getItem: (key: string) => Promise<any>;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
persistThrottle: number;
};
} }
const InvokeAIUI = ({ const InvokeAIUI = ({
@@ -96,6 +106,7 @@ const InvokeAIUI = ({
loggingOverrides, loggingOverrides,
onClickGoToModelManager, onClickGoToModelManager,
whatsNew, whatsNew,
storageConfig,
}: Props) => { }: Props) => {
useLayoutEffect(() => { useLayoutEffect(() => {
/* /*
@@ -308,9 +319,21 @@ const InvokeAIUI = ({
}; };
}, [isDebugging]); }, [isDebugging]);
const storage = useMemo(() => buildStorageApi(storageConfig), [storageConfig]);
useEffect(() => {
const storageCleanup = storage.registerListeners();
return () => {
storageCleanup();
};
}, [storage]);
const store = useMemo(() => { const store = useMemo(() => {
return createStore(projectId); return createStore({
}, [projectId]); driver: storage.reduxRememberDriver,
persistThrottle: storageConfig?.persistThrottle ?? 2000,
});
}, [storage.reduxRememberDriver, storageConfig?.persistThrottle]);
useEffect(() => { useEffect(() => {
$store.set(store); $store.set(store);
@@ -327,11 +350,13 @@ const InvokeAIUI = ({
return ( return (
<React.StrictMode> <React.StrictMode>
<Provider store={store}> <ClearStorageProvider value={storage.clearStorage}>
<React.Suspense fallback={<Loading />}> <Provider store={store}>
<App config={config} studioInitAction={studioInitAction} /> <React.Suspense fallback={<Loading />}>
</React.Suspense> <App config={config} studioInitAction={studioInitAction} />
</Provider> </React.Suspense>
</Provider>
</ClearStorageProvider>
</React.StrictMode> </React.StrictMode>
); );
}; };

View File

@@ -0,0 +1,10 @@
import { createContext, useContext } from 'react';
const ClearStorageContext = createContext<() => void>(() => {});
export const ClearStorageProvider = ClearStorageContext.Provider;
export const useClearStorage = () => {
const context = useContext(ClearStorageContext);
return context;
};

View File

@@ -1,3 +1,2 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = []; export const EMPTY_ARRAY = [];
export const EMPTY_OBJECT = {}; export const EMPTY_OBJECT = {};

View File

@@ -1,40 +1,243 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { logger } from 'app/logging/logger';
import { StorageError } from 'app/store/enhancers/reduxRemember/errors'; import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
import { $projectId } from 'app/store/nanostores/projectId'; import { $projectId } from 'app/store/nanostores/projectId';
import type { UseStore } from 'idb-keyval'; import type { Driver as ReduxRememberDriver } from 'redux-remember';
import { clear, createStore as createIDBKeyValStore, get, set } from 'idb-keyval'; import { getBaseUrl } from 'services/api';
import { atom } from 'nanostores'; import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
import type { Driver } from 'redux-remember';
// Create a custom idb-keyval store (just needed to customize the name) const log = logger('system');
const $idbKeyValStore = atom<UseStore>(createIDBKeyValStore('invoke', 'invoke-store'));
export const clearIdbKeyValStore = () => { const buildOSSServerBackedDriver = (): {
clear($idbKeyValStore.get()); reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
let persistRefCount = 0;
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
//
// `redux-remember` persists individual slices of state, so we can implicity denylist a slice by not giving it a
// persist config.
//
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
// way to do this directly.
//
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
// the implementation in `store.ts` for this logic.
//
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
// whole slice, even if the final, _serialized_ slice value is unchanged.
//
// To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to
// be persisted is the same as the last persisted value, we can skip the network request.
const lastPersistedState = new Map<string, unknown>();
const getUrl = (key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
if (key) {
query['key'] = key;
}
const path = buildAppInfoUrl('client_state', query);
const url = `${baseUrl}/${path}`;
return url;
};
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
const url = getUrl(key);
const res = await fetch(url, { method: 'GET' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
const text = await res.text();
if (!lastPersistedState.get(key)) {
lastPersistedState.set(key, text);
}
return JSON.parse(text);
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key, value) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping persist for key "${key}" as value is unchanged.`);
return value;
}
const url = getUrl(key);
const headers = new Headers({
'Content-Type': 'application/json',
});
const res = await fetch(url, { method: 'POST', headers, body: value });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
lastPersistedState.set(key, value);
return value;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
const clearStorage = async () => {
try {
persistRefCount++;
const url = getUrl();
const res = await fetch(url, { method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch {
log.error('Failed to reset client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
const registerListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};
return { reduxRememberDriver, clearStorage, registerListeners };
}; };
// Create redux-remember driver, wrapping idb-keyval const buildCustomDriver = (api: {
export const idbKeyValDriver: Driver = { getItem: (key: string) => Promise<any>;
getItem: (key) => { setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
}): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
let persistRefCount = 0;
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
const lastPersistedState = new Map<string, unknown>();
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
log.trace(`Getting client state for key "${key}"`);
return await api.getItem(key);
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key, value) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping setting client state for key "${key}" as value is unchanged`);
return value;
}
log.trace(`Setting client state for key "${key}", ${value}`);
await api.setItem(key, value);
lastPersistedState.set(key, value);
return value;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
const clearStorage = async () => {
try { try {
return get(key, $idbKeyValStore.get()); persistRefCount++;
} catch (originalError) { log.trace('Clearing client state');
throw new StorageError({ await api.clear();
key, } catch {
projectId: $projectId.get(), log.error('Failed to clear client state');
originalError, } finally {
}); persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
} }
}, };
setItem: (key, value) => {
try { const registerListeners = () => {
return set(key, value, $idbKeyValStore.get()); const onBeforeUnload = (e: BeforeUnloadEvent) => {
} catch (originalError) { if (persistRefCount > 0) {
throw new StorageError({ e.preventDefault();
key, }
value, };
projectId: $projectId.get(), window.addEventListener('beforeunload', onBeforeUnload);
originalError,
}); return () => {
} window.removeEventListener('beforeunload', onBeforeUnload);
}, };
};
return { reduxRememberDriver, clearStorage, registerListeners };
};
export const buildStorageApi = (api?: {
getItem: (key: string) => Promise<any>;
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
}) => {
if (api) {
return buildCustomDriver(api);
} else {
return buildOSSServerBackedDriver();
}
}; };

View File

@@ -1,73 +0,0 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addImageUploadedFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageUploaded';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
const startAppListening = listenerMiddleware.startListening as AppStartListening;
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
/**
* The RTK listener middleware is a lightweight alternative sagas/observables.
*
* Most side effect logic should live in a listener.
*/
// Image uploaded
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);
// app startup
addAppStartedListener(startAppListening);
addModelsLoadedListener(startAppListening);
addAppConfigReceivedListener(startAppListening);
// Ad-hoc upscale workflwo
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@@ -1,6 +1,6 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph'; import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';

View File

@@ -1,5 +1,5 @@
import { isAnyOf } from '@reduxjs/toolkit'; import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors'; import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { import {
autoAddBoardIdChanged, autoAddBoardIdChanged,

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue'; import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => { export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice'; import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice'; import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo'; import { appInfoApi } from 'services/api/endpoints/appInfo';

View File

@@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { truncate } from 'es-toolkit/compat'; import { truncate } from 'es-toolkit/compat';
import { zPydanticValidationError } from 'features/system/store/zodSchemas'; import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/state'; import { getImageUsage } from 'features/deleteImageModal/store/state';

View File

@@ -1,5 +1,5 @@
import { isAnyOf } from '@reduxjs/toolkit'; import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors'; import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice'; import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { size } from 'es-toolkit/compat'; import { size } from 'es-toolkit/compat';
import { $templates } from 'features/nodes/store/nodesSlice'; import { $templates } from 'features/nodes/store/nodesSlice';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery'); const log = logger('gallery');

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery'); const log = logger('gallery');

View File

@@ -1,7 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit'; import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import { omit } from 'es-toolkit/compat'; import { omit } from 'es-toolkit/compat';
import { imageUploadedClientSide } from 'features/gallery/store/actions'; import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors'; import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice'; import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice'; import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice'; import { loraDeleted } from 'features/controlLayers/store/lorasSlice';

View File

@@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppDispatch, AppStartListening, RootState } from 'app/store/store';
import type { AppDispatch, RootState } from 'app/store/store';
import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice'; import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice'; import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { import {

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/store';
import { isNil } from 'es-toolkit'; import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice'; import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice'; import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';

View File

@@ -1,8 +1,8 @@
import { objectEquals } from '@observ33r/object-equals'; import { objectEquals } from '@observ33r/object-equals';
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $baseUrl } from 'app/store/nanostores/baseUrl';
import type { AppStartListening } from 'app/store/store';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import { api } from 'services/api'; import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';

View File

@@ -1,35 +1,46 @@
import type { ThunkDispatch, UnknownAction } from '@reduxjs/toolkit'; import type { ThunkDispatch, TypedStartListening, UnknownAction } from '@reduxjs/toolkit';
import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/toolkit'; import { addListener, combineReducers, configureStore, createListenerMiddleware } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors'; import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { keys, mergeWith, omit, pick } from 'es-toolkit/compat'; import { keys, mergeWith, omit, pick } from 'es-toolkit/compat';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice'; import { changeBoardModalSliceConfig } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice'; import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice'; import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice';
import { import { canvasSessionSliceConfig } from 'features/controlLayers/store/canvasStagingAreaSlice';
canvasSessionSlice, import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice';
canvasStagingAreaPersistConfig, import { paramsSliceConfig } from 'features/controlLayers/store/paramsSlice';
} from 'features/controlLayers/store/canvasStagingAreaSlice'; import { refImagesSliceConfig } from 'features/controlLayers/store/refImagesSlice';
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice'; import { dynamicPromptsSliceConfig } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice'; import { gallerySliceConfig } from 'features/gallery/store/gallerySlice';
import { refImagesPersistConfig, refImagesSlice } from 'features/controlLayers/store/refImagesSlice'; import { modelManagerSliceConfig } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { nodesSliceConfig } from 'features/nodes/store/nodesSlice';
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice'; import { workflowLibrarySliceConfig } from 'features/nodes/store/workflowLibrarySlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { workflowSettingsSliceConfig } from 'features/nodes/store/workflowSettingsSlice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice'; import { upscaleSliceConfig } from 'features/parameters/store/upscaleSlice';
import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nodes/store/workflowLibrarySlice'; import { queueSliceConfig } from 'features/queue/store/queueSlice';
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; import { stylePresetSliceConfig } from 'features/stylePresets/store/stylePresetSlice';
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice'; import { configSliceConfig } from 'features/system/store/configSlice';
import { queueSlice } from 'features/queue/store/queueSlice'; import { systemSliceConfig } from 'features/system/store/systemSlice';
import { stylePresetPersistConfig, stylePresetSlice } from 'features/stylePresets/store/stylePresetSlice'; import { uiSliceConfig } from 'features/ui/store/uiSlice';
import { configSlice } from 'features/system/store/configSlice';
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
import { diff } from 'jsondiffpatch'; import { diff } from 'jsondiffpatch';
import dynamicMiddlewares from 'redux-dynamic-middlewares'; import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember'; import type { Driver, SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember'; import { rememberEnhancer, rememberReducer } from 'redux-remember';
import undoable, { newHistory } from 'redux-undo'; import undoable, { newHistory } from 'redux-undo';
import { serializeError } from 'serialize-error'; import { serializeError } from 'serialize-error';
@@ -37,123 +48,116 @@ import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware'; import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest'; import type { JsonObject } from 'type-fest';
import { STORAGE_PREFIX } from './constants';
import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer'; import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { listenerMiddleware } from './middleware/listenerMiddleware'; import { addArchivedOrDeletedBoardListener } from './middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener';
import { addImageUploadedFulfilledListener } from './middleware/listenerMiddleware/listeners/imageUploaded';
export const listenerMiddleware = createListenerMiddleware();
const log = logger('system'); const log = logger('system');
const allReducers = { // When adding a slice, add the config to the SLICE_CONFIGS object below, then add the reducer to ALL_REDUCERS.
[api.reducerPath]: api.reducer, const SLICE_CONFIGS = {
[gallerySlice.name]: gallerySlice.reducer, [canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig,
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig), [canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
[systemSlice.name]: systemSlice.reducer, [canvasSliceConfig.slice.reducerPath]: canvasSliceConfig,
[configSlice.name]: configSlice.reducer, [changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig,
[uiSlice.name]: uiSlice.reducer, [configSliceConfig.slice.reducerPath]: configSliceConfig,
[dynamicPromptsSlice.name]: dynamicPromptsSlice.reducer, [dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig,
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer, [gallerySliceConfig.slice.reducerPath]: gallerySliceConfig,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer, [lorasSliceConfig.slice.reducerPath]: lorasSliceConfig,
[queueSlice.name]: queueSlice.reducer, [modelManagerSliceConfig.slice.reducerPath]: modelManagerSliceConfig,
[canvasSlice.name]: undoable(canvasSlice.reducer, canvasUndoableConfig), [nodesSliceConfig.slice.reducerPath]: nodesSliceConfig,
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer, [paramsSliceConfig.slice.reducerPath]: paramsSliceConfig,
[upscaleSlice.name]: upscaleSlice.reducer, [queueSliceConfig.slice.reducerPath]: queueSliceConfig,
[stylePresetSlice.name]: stylePresetSlice.reducer, [refImagesSliceConfig.slice.reducerPath]: refImagesSliceConfig,
[paramsSlice.name]: paramsSlice.reducer, [stylePresetSliceConfig.slice.reducerPath]: stylePresetSliceConfig,
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer, [systemSliceConfig.slice.reducerPath]: systemSliceConfig,
[canvasSessionSlice.name]: canvasSessionSlice.reducer, [uiSliceConfig.slice.reducerPath]: uiSliceConfig,
[lorasSlice.name]: lorasSlice.reducer, [upscaleSliceConfig.slice.reducerPath]: upscaleSliceConfig,
[workflowLibrarySlice.name]: workflowLibrarySlice.reducer, [workflowLibrarySliceConfig.slice.reducerPath]: workflowLibrarySliceConfig,
[refImagesSlice.name]: refImagesSlice.reducer, [workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig,
}; };
const rootReducer = combineReducers(allReducers); // TS makes it really hard to dynamically create this object :/ so it's just hardcoded here.
// Remember to wrap undoable reducers in `undoable()`!
const ALL_REDUCERS = {
[api.reducerPath]: api.reducer,
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig.slice.reducer,
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig.slice.reducer,
// Undoable!
[canvasSliceConfig.slice.reducerPath]: undoable(
canvasSliceConfig.slice.reducer,
canvasSliceConfig.undoableConfig?.reduxUndoOptions
),
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig.slice.reducer,
[configSliceConfig.slice.reducerPath]: configSliceConfig.slice.reducer,
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig.slice.reducer,
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig.slice.reducer,
[lorasSliceConfig.slice.reducerPath]: lorasSliceConfig.slice.reducer,
[modelManagerSliceConfig.slice.reducerPath]: modelManagerSliceConfig.slice.reducer,
// Undoable!
[nodesSliceConfig.slice.reducerPath]: undoable(
nodesSliceConfig.slice.reducer,
nodesSliceConfig.undoableConfig?.reduxUndoOptions
),
[paramsSliceConfig.slice.reducerPath]: paramsSliceConfig.slice.reducer,
[queueSliceConfig.slice.reducerPath]: queueSliceConfig.slice.reducer,
[refImagesSliceConfig.slice.reducerPath]: refImagesSliceConfig.slice.reducer,
[stylePresetSliceConfig.slice.reducerPath]: stylePresetSliceConfig.slice.reducer,
[systemSliceConfig.slice.reducerPath]: systemSliceConfig.slice.reducer,
[uiSliceConfig.slice.reducerPath]: uiSliceConfig.slice.reducer,
[upscaleSliceConfig.slice.reducerPath]: upscaleSliceConfig.slice.reducer,
[workflowLibrarySliceConfig.slice.reducerPath]: workflowLibrarySliceConfig.slice.reducer,
[workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig.slice.reducer,
};
const rootReducer = combineReducers(ALL_REDUCERS);
const rememberedRootReducer = rememberReducer(rootReducer); const rememberedRootReducer = rememberReducer(rootReducer);
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export type PersistConfig<T = any> = {
/**
* The name of the slice.
*/
name: keyof typeof allReducers;
/**
* The initial state of the slice.
*/
initialState: T;
/**
* Migrate the state to the current version during rehydration.
* @param state The rehydrated state.
* @returns A correctly-shaped state.
*/
migrate: (state: unknown) => T;
/**
* Keys to omit from the persisted state.
*/
persistDenylist: (keyof T)[];
};
const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[galleryPersistConfig.name]: galleryPersistConfig,
[nodesPersistConfig.name]: nodesPersistConfig,
[systemPersistConfig.name]: systemPersistConfig,
[uiPersistConfig.name]: uiPersistConfig,
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[canvasPersistConfig.name]: canvasPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
[upscalePersistConfig.name]: upscalePersistConfig,
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
[paramsPersistConfig.name]: paramsPersistConfig,
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
[lorasPersistConfig.name]: lorasPersistConfig,
[workflowLibraryPersistConfig.name]: workflowLibraryPersistConfig,
[refImagesSlice.name]: refImagesPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => { const unserialize: UnserializeFunction = (data, key) => {
const persistConfig = persistConfigs[key as keyof typeof persistConfigs]; const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
if (!persistConfig) { if (!sliceConfig?.persistConfig) {
throw new Error(`No persist config for slice "${key}"`); throw new Error(`No persist config for slice "${key}"`);
} }
const { getInitialState, persistConfig, undoableConfig } = sliceConfig;
let state; let state;
try { try {
const { initialState, migrate } = persistConfig; const initialState = getInitialState();
const parsed = JSON.parse(data);
// strip out old keys // strip out old keys
const stripped = pick(deepClone(parsed), keys(initialState)); const stripped = pick(deepClone(data), keys(initialState));
// run (additive) migrations
const migrated = migrate(stripped);
/* /*
* Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep, * Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep,
* but that merges arrays by index and partial objects by key. Using an identity function as the customizer results * but that merges arrays by index and partial objects by key. Using an identity function as the customizer results
* in behaviour like defaultsDeep, but doesn't overwrite any values that are not undefined in the migrated state. * in behaviour like defaultsDeep, but doesn't overwrite any values that are not undefined in the migrated state.
*/ */
const transformed = mergeWith(migrated, initialState, (objVal) => objVal); const unPersistDenylisted = mergeWith(stripped, initialState, (objVal) => objVal);
// run (additive) migrations
const migrated = persistConfig.migrate(unPersistDenylisted);
log.debug( log.debug(
{ {
persistedData: parsed, persistedData: data as JsonObject,
rehydratedData: transformed, rehydratedData: migrated as JsonObject,
diff: diff(parsed, transformed) as JsonObject, // this is always serializable diff: diff(data, migrated) as JsonObject,
}, },
`Rehydrated slice "${key}"` `Rehydrated slice "${key}"`
); );
state = transformed; state = migrated;
} catch (err) { } catch (err) {
log.warn( log.warn(
{ error: serializeError(err as Error) }, { error: serializeError(err as Error) },
`Error rehydrating slice "${key}", falling back to default initial state` `Error rehydrating slice "${key}", falling back to default initial state`
); );
state = persistConfig.initialState; state = getInitialState();
} }
// If the slice is undoable, we need to wrap it in a new history - only nodes and canvas are undoable at the moment. // Undoable slices must be wrapped in a history!
// TODO(psyche): make this automatic & remove the hard-coding for specific slices. if (undoableConfig) {
if (key === nodesSlice.name || key === canvasSlice.name) {
return newHistory([], state, []); return newHistory([], state, []);
} else { } else {
return state; return state;
@@ -161,21 +165,30 @@ const unserialize: UnserializeFunction = (data, key) => {
}; };
const serialize: SerializeFunction = (data, key) => { const serialize: SerializeFunction = (data, key) => {
const persistConfig = persistConfigs[key as keyof typeof persistConfigs]; const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
if (!persistConfig) { if (!sliceConfig?.persistConfig) {
throw new Error(`No persist config for slice "${key}"`); throw new Error(`No persist config for slice "${key}"`);
} }
// Heuristic to determine if the slice is undoable - could just hardcode it in the persistConfig
const isUndoable = 'present' in data && 'past' in data && 'future' in data && '_latestUnfiltered' in data; const result = omit(
const result = omit(isUndoable ? data.present : data, persistConfig.persistDenylist); sliceConfig.undoableConfig ? data.present : data,
sliceConfig.persistConfig.persistDenylist ?? []
);
return JSON.stringify(result); return JSON.stringify(result);
}; };
export const createStore = (uniqueStoreKey?: string, persist = true) => const PERSISTED_KEYS = Object.values(SLICE_CONFIGS)
.filter((sliceConfig) => !!sliceConfig.persistConfig)
.map((sliceConfig) => sliceConfig.slice.reducerPath);
export const createStore = (reduxRememberOptions: { driver: Driver; persistThrottle: number }) =>
configureStore({ configureStore({
reducer: rememberedRootReducer, reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) => middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({ getDefaultMiddleware({
// serializableCheck: false,
// immutableCheck: false,
serializableCheck: import.meta.env.MODE === 'development', serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development', immutableCheck: import.meta.env.MODE === 'development',
}) })
@@ -185,19 +198,16 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
// .concat(getDebugLoggerMiddleware()) // .concat(getDebugLoggerMiddleware())
.prepend(listenerMiddleware.middleware), .prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => { enhancers: (getDefaultEnhancers) => {
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer()); const enhancers = getDefaultEnhancers();
if (persist) { return enhancers.prepend(
_enhancers.push( rememberEnhancer(reduxRememberOptions.driver, PERSISTED_KEYS, {
rememberEnhancer(idbKeyValDriver, keys(persistConfigs), { persistThrottle: reduxRememberOptions.persistThrottle,
persistDebounce: 300, serialize,
serialize, unserialize,
unserialize, prefix: '',
prefix: uniqueStoreKey ? `${STORAGE_PREFIX}${uniqueStoreKey}-` : STORAGE_PREFIX, errorHandler,
errorHandler, })
}) );
);
}
return _enhancers;
}, },
devTools: { devTools: {
actionSanitizer, actionSanitizer,
@@ -214,7 +224,48 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
export type AppStore = ReturnType<typeof createStore>; export type AppStore = ReturnType<typeof createStore>;
export type RootState = ReturnType<AppStore['getState']>; export type RootState = ReturnType<AppStore['getState']>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any /* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>; export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
export type AppDispatch = ReturnType<typeof createStore>['dispatch']; export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
export type AppGetState = ReturnType<typeof createStore>['getState']; export type AppGetState = ReturnType<typeof createStore>['getState'];
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
const startAppListening = listenerMiddleware.startListening as AppStartListening;
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);
// app startup
addAppStartedListener(startAppListening);
addModelsLoadedListener(startAppListening);
addAppConfigReceivedListener(startAppListening);
// Ad-hoc upscale workflwo
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@@ -0,0 +1,46 @@
import type { Slice } from '@reduxjs/toolkit';
import type { UndoableOptions } from 'redux-undo';
import type { ZodType } from 'zod';
type StateFromSlice<T extends Slice> = T extends Slice<infer U> ? U : never;
export type SliceConfig<T extends Slice> = {
/**
* The redux slice (return of createSlice).
*/
slice: T;
/**
* The zod schema for the slice.
*/
schema: ZodType<StateFromSlice<T>>;
/**
* A function that returns the initial state of the slice.
*/
getInitialState: () => StateFromSlice<T>;
/**
* The optional persist configuration for this slice. If omitted, the slice will not be persisted.
*/
persistConfig?: {
/**
* Migrate the state to the current version during rehydration. This method should throw an error if the migration
* fails.
*
* @param state The rehydrated state.
* @returns A correctly-shaped state.
*/
migrate: (state: unknown) => StateFromSlice<T>;
/**
* Keys to omit from the persisted state.
*/
persistDenylist?: (keyof StateFromSlice<T>)[];
};
/**
* The optional undoable configuration for this slice. If omitted, the slice will not be undoable.
*/
undoableConfig?: {
/**
* The options to be passed into redux-undo.
*/
reduxUndoOptions: UndoableOptions<StateFromSlice<T>>;
};
};

View File

@@ -1,130 +1,299 @@
import type { FilterType } from 'features/controlLayers/store/filters'; import { zFilterType } from 'features/controlLayers/store/filters';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas'; import { zParameterPrecision, zParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { TabName } from 'features/ui/store/uiTypes'; import { zTabName } from 'features/ui/store/uiTypes';
import type { PartialDeep } from 'type-fest'; import type { PartialDeep } from 'type-fest';
import z from 'zod';
/** const zAppFeature = z.enum([
* A disable-able application feature 'faceRestore',
*/ 'upscaling',
export type AppFeature = 'lightbox',
| 'faceRestore' 'modelManager',
| 'upscaling' 'githubLink',
| 'lightbox' 'discordLink',
| 'modelManager' 'bugLink',
| 'githubLink' 'aboutModal',
| 'discordLink' 'localization',
| 'bugLink' 'consoleLogging',
| 'aboutModal' 'dynamicPrompting',
| 'localization' 'batches',
| 'consoleLogging' 'syncModels',
| 'dynamicPrompting' 'multiselect',
| 'batches' 'pauseQueue',
| 'syncModels' 'resumeQueue',
| 'multiselect' 'invocationCache',
| 'pauseQueue' 'modelCache',
| 'resumeQueue' 'bulkDownload',
| 'invocationCache' 'starterModels',
| 'modelCache' 'hfToken',
| 'bulkDownload' 'retryQueueItem',
| 'starterModels' 'cancelAndClearAll',
| 'hfToken' 'chatGPT4oHigh',
| 'retryQueueItem' 'modelRelationships',
| 'cancelAndClearAll' ]);
| 'chatGPT4oHigh' export type AppFeature = z.infer<typeof zAppFeature>;
| 'modelRelationships';
/**
* A disable-able Stable Diffusion feature
*/
export type SDFeature =
| 'controlNet'
| 'noise'
| 'perlinNoise'
| 'noiseThreshold'
| 'variation'
| 'symmetry'
| 'seamless'
| 'hires'
| 'lora'
| 'embedding'
| 'vae'
| 'hrf';
export type NumericalParameterConfig = { const zSDFeature = z.enum([
initial: number; 'controlNet',
sliderMin: number; 'noise',
sliderMax: number; 'perlinNoise',
numberInputMin: number; 'noiseThreshold',
numberInputMax: number; 'variation',
fineStep: number; 'symmetry',
coarseStep: number; 'seamless',
}; 'hires',
'lora',
'embedding',
'vae',
'hrf',
]);
export type SDFeature = z.infer<typeof zSDFeature>;
const zNumericalParameterConfig = z.object({
initial: z.number().default(512),
sliderMin: z.number().default(64),
sliderMax: z.number().default(1536),
numberInputMin: z.number().default(64),
numberInputMax: z.number().default(4096),
fineStep: z.number().default(8),
coarseStep: z.number().default(64),
});
/** /**
* Configuration options for the InvokeAI UI. * Configuration options for the InvokeAI UI.
* Distinct from system settings which may be changed inside the app. * Distinct from system settings which may be changed inside the app.
*/ */
export type AppConfig = { export const zAppConfig = z.object({
/** /**
* Whether or not we should update image urls when image loading errors * Whether or not we should update image urls when image loading errors
*/ */
shouldUpdateImagesOnConnect: boolean; shouldUpdateImagesOnConnect: z.boolean(),
shouldFetchMetadataFromApi: boolean; shouldFetchMetadataFromApi: z.boolean(),
/** /**
* Sets a size limit for outputs on the upscaling tab. This is a maximum dimension, so the actual max number of pixels * Sets a size limit for outputs on the upscaling tab. This is a maximum dimension, so the actual max number of pixels
* will be the square of this value. * will be the square of this value.
*/ */
maxUpscaleDimension?: number; maxUpscaleDimension: z.number().optional(),
allowPrivateBoards: boolean; allowPrivateBoards: z.boolean(),
allowPrivateStylePresets: boolean; allowPrivateStylePresets: z.boolean(),
allowClientSideUpload: boolean; allowClientSideUpload: z.boolean(),
allowPublishWorkflows: boolean; allowPublishWorkflows: z.boolean(),
allowPromptExpansion: boolean; allowPromptExpansion: z.boolean(),
disabledTabs: TabName[]; disabledTabs: z.array(zTabName),
disabledFeatures: AppFeature[]; disabledFeatures: z.array(zAppFeature),
disabledSDFeatures: SDFeature[]; disabledSDFeatures: z.array(zSDFeature),
nodesAllowlist: string[] | undefined; nodesAllowlist: z.array(z.string()).optional(),
nodesDenylist: string[] | undefined; nodesDenylist: z.array(z.string()).optional(),
metadataFetchDebounce?: number; metadataFetchDebounce: z.number().int().optional(),
workflowFetchDebounce?: number; workflowFetchDebounce: z.number().int().optional(),
isLocal?: boolean; isLocal: z.boolean().optional(),
shouldShowCredits: boolean; shouldShowCredits: z.boolean().optional(),
sd: { sd: z.object({
defaultModel?: string; defaultModel: z.string().optional(),
disabledControlNetModels: string[]; disabledControlNetModels: z.array(z.string()),
disabledControlNetProcessors: FilterType[]; disabledControlNetProcessors: z.array(zFilterType),
// Core parameters // Core parameters
iterations: NumericalParameterConfig; iterations: zNumericalParameterConfig,
width: NumericalParameterConfig; // initial value comes from model width: zNumericalParameterConfig,
height: NumericalParameterConfig; // initial value comes from model height: zNumericalParameterConfig,
steps: NumericalParameterConfig; steps: zNumericalParameterConfig,
guidance: NumericalParameterConfig; guidance: zNumericalParameterConfig,
cfgRescaleMultiplier: NumericalParameterConfig; cfgRescaleMultiplier: zNumericalParameterConfig,
img2imgStrength: NumericalParameterConfig; img2imgStrength: zNumericalParameterConfig,
scheduler?: ParameterScheduler; scheduler: zParameterScheduler.optional(),
vaePrecision?: ParameterPrecision; vaePrecision: zParameterPrecision.optional(),
// Canvas // Canvas
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model boundingBoxHeight: zNumericalParameterConfig,
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model boundingBoxWidth: zNumericalParameterConfig,
scaledBoundingBoxHeight: NumericalParameterConfig; // initial value comes from model scaledBoundingBoxHeight: zNumericalParameterConfig,
scaledBoundingBoxWidth: NumericalParameterConfig; // initial value comes from model scaledBoundingBoxWidth: zNumericalParameterConfig,
canvasCoherenceStrength: NumericalParameterConfig; canvasCoherenceStrength: zNumericalParameterConfig,
canvasCoherenceEdgeSize: NumericalParameterConfig; canvasCoherenceEdgeSize: zNumericalParameterConfig,
infillTileSize: NumericalParameterConfig; infillTileSize: zNumericalParameterConfig,
infillPatchmatchDownscaleSize: NumericalParameterConfig; infillPatchmatchDownscaleSize: zNumericalParameterConfig,
// Misc advanced // Misc advanced
clipSkip: NumericalParameterConfig; // slider and input max are ignored for this, because the values depend on the model clipSkip: zNumericalParameterConfig, // slider and input max are ignored for this, because the values depend on the model
maskBlur: NumericalParameterConfig; maskBlur: zNumericalParameterConfig,
hrfStrength: NumericalParameterConfig; hrfStrength: zNumericalParameterConfig,
dynamicPrompts: { dynamicPrompts: z.object({
maxPrompts: NumericalParameterConfig; maxPrompts: zNumericalParameterConfig,
}; }),
ca: { ca: z.object({
weight: NumericalParameterConfig; weight: zNumericalParameterConfig,
}; }),
}; }),
flux: { flux: z.object({
guidance: NumericalParameterConfig; guidance: zNumericalParameterConfig,
}; }),
}; });
export type AppConfig = z.infer<typeof zAppConfig>;
export type PartialAppConfig = PartialDeep<AppConfig>; export type PartialAppConfig = PartialDeep<AppConfig>;
export const getDefaultAppConfig = (): AppConfig => ({
isLocal: true,
shouldUpdateImagesOnConnect: false,
shouldFetchMetadataFromApi: false,
allowPrivateBoards: false,
allowPrivateStylePresets: false,
allowClientSideUpload: false,
allowPublishWorkflows: false,
allowPromptExpansion: false,
shouldShowCredits: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'] satisfies AppFeature[],
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'] satisfies SDFeature[],
sd: {
disabledControlNetModels: [],
disabledControlNetProcessors: [],
iterations: {
initial: 1,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 1,
},
width: zNumericalParameterConfig.parse({}), // initial value comes from model
height: zNumericalParameterConfig.parse({}), // initial value comes from model
boundingBoxWidth: zNumericalParameterConfig.parse({}), // initial value comes from model
boundingBoxHeight: zNumericalParameterConfig.parse({}), // initial value comes from model
scaledBoundingBoxWidth: zNumericalParameterConfig.parse({}), // initial value comes from model
scaledBoundingBoxHeight: zNumericalParameterConfig.parse({}), // initial value comes from model
scheduler: 'dpmpp_3m_k' as const,
vaePrecision: 'fp32' as const,
steps: {
initial: 30,
sliderMin: 1,
sliderMax: 100,
numberInputMin: 1,
numberInputMax: 500,
fineStep: 1,
coarseStep: 1,
},
guidance: {
initial: 7,
sliderMin: 1,
sliderMax: 20,
numberInputMin: 1,
numberInputMax: 200,
fineStep: 0.1,
coarseStep: 0.5,
},
img2imgStrength: {
initial: 0.7,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceStrength: {
initial: 0.3,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
hrfStrength: {
initial: 0.45,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceEdgeSize: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 1024,
fineStep: 8,
coarseStep: 16,
},
cfgRescaleMultiplier: {
initial: 0,
sliderMin: 0,
sliderMax: 0.99,
numberInputMin: 0,
numberInputMax: 0.99,
fineStep: 0.05,
coarseStep: 0.1,
},
clipSkip: {
initial: 0,
sliderMin: 0,
sliderMax: 12, // determined by model selection, unused in practice
numberInputMin: 0,
numberInputMax: 12, // determined by model selection, unused in practice
fineStep: 1,
coarseStep: 1,
},
infillPatchmatchDownscaleSize: {
initial: 1,
sliderMin: 1,
sliderMax: 10,
numberInputMin: 1,
numberInputMax: 10,
fineStep: 1,
coarseStep: 1,
},
infillTileSize: {
initial: 32,
sliderMin: 16,
sliderMax: 64,
numberInputMin: 16,
numberInputMax: 256,
fineStep: 1,
coarseStep: 1,
},
maskBlur: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 512,
fineStep: 1,
coarseStep: 1,
},
ca: {
weight: {
initial: 1,
sliderMin: 0,
sliderMax: 2,
numberInputMin: -1,
numberInputMax: 2,
fineStep: 0.01,
coarseStep: 0.05,
},
},
dynamicPrompts: {
maxPrompts: {
initial: 100,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 10,
},
},
},
flux: {
guidance: {
initial: 4,
sliderMin: 2,
sliderMax: 6,
numberInputMin: 1,
numberInputMax: 20,
fineStep: 0.1,
coarseStep: 0.5,
},
},
});

View File

@@ -1,11 +0,0 @@
import { clearIdbKeyValStore } from 'app/store/enhancers/reduxRemember/driver';
import { useCallback } from 'react';
export const useClearStorage = () => {
const clearStorage = useCallback(() => {
clearIdbKeyValStore();
localStorage.clear();
}, []);
return clearStorage;
};

View File

@@ -1,6 +0,0 @@
import type { ChangeBoardModalState } from './types';
export const initialState: ChangeBoardModalState = {
isModalOpen: false,
image_names: [],
};

View File

@@ -1,12 +1,20 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import z from 'zod';
import { initialState } from './initialState'; const zChangeBoardModalState = z.object({
isModalOpen: z.boolean().default(false),
image_names: z.array(z.string()).default(() => []),
});
type ChangeBoardModalState = z.infer<typeof zChangeBoardModalState>;
export const changeBoardModalSlice = createSlice({ const getInitialState = (): ChangeBoardModalState => zChangeBoardModalState.parse({});
const slice = createSlice({
name: 'changeBoardModal', name: 'changeBoardModal',
initialState, initialState: getInitialState(),
reducers: { reducers: {
isModalOpenChanged: (state, action: PayloadAction<boolean>) => { isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isModalOpen = action.payload; state.isModalOpen = action.payload;
@@ -21,6 +29,12 @@ export const changeBoardModalSlice = createSlice({
}, },
}); });
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } = changeBoardModalSlice.actions; export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } = slice.actions;
export const selectChangeBoardModalSlice = (state: RootState) => state.changeBoardModal; export const selectChangeBoardModalSlice = (state: RootState) => state.changeBoardModal;
export const changeBoardModalSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zChangeBoardModalState,
getInitialState,
};

View File

@@ -1,4 +0,0 @@
export type ChangeBoardModalState = {
isModalOpen: boolean;
image_names: string[];
};

View File

@@ -1,7 +1,7 @@
import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library'; import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Selector } from '@reduxjs/toolkit'; import type { Selector } from '@reduxjs/toolkit';
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
import type { AppStore, RootState } from 'app/store/store'; import type { AppStore, RootState } from 'app/store/store';
import { addAppListener } from 'app/store/store';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer'; import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer'; import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';

View File

@@ -1,6 +1,7 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { zRgbaColor } from 'features/controlLayers/store/types'; import { zRgbaColor } from 'features/controlLayers/store/types';
import { z } from 'zod'; import { z } from 'zod';
@@ -11,32 +12,32 @@ const zCanvasSettingsState = z.object({
/** /**
* Whether to show HUD (Heads-Up Display) on the canvas. * Whether to show HUD (Heads-Up Display) on the canvas.
*/ */
showHUD: z.boolean().default(true), showHUD: z.boolean(),
/** /**
* Whether to clip lines and shapes to the generation bounding box. If disabled, lines and shapes will be clipped to * Whether to clip lines and shapes to the generation bounding box. If disabled, lines and shapes will be clipped to
* the canvas bounds. * the canvas bounds.
*/ */
clipToBbox: z.boolean().default(false), clipToBbox: z.boolean(),
/** /**
* Whether to show a dynamic grid on the canvas. If disabled, a checkerboard pattern will be shown instead. * Whether to show a dynamic grid on the canvas. If disabled, a checkerboard pattern will be shown instead.
*/ */
dynamicGrid: z.boolean().default(false), dynamicGrid: z.boolean(),
/** /**
* Whether to invert the scroll direction when adjusting the brush or eraser width with the scroll wheel. * Whether to invert the scroll direction when adjusting the brush or eraser width with the scroll wheel.
*/ */
invertScrollForToolWidth: z.boolean().default(false), invertScrollForToolWidth: z.boolean(),
/** /**
* The width of the brush tool. * The width of the brush tool.
*/ */
brushWidth: z.int().gt(0).default(50), brushWidth: z.int().gt(0),
/** /**
* The width of the eraser tool. * The width of the eraser tool.
*/ */
eraserWidth: z.int().gt(0).default(50), eraserWidth: z.int().gt(0),
/** /**
* The color to use when drawing lines or filling shapes. * The color to use when drawing lines or filling shapes.
*/ */
color: zRgbaColor.default({ r: 31, g: 160, b: 224, a: 1 }), // invokeBlue.500 color: zRgbaColor,
/** /**
* Whether to composite inpainted/outpainted regions back onto the source image when saving canvas generations. * Whether to composite inpainted/outpainted regions back onto the source image when saving canvas generations.
* *
@@ -44,57 +45,77 @@ const zCanvasSettingsState = z.object({
* *
* When `sendToCanvas` is disabled, this setting is ignored, masked regions will always be composited. * When `sendToCanvas` is disabled, this setting is ignored, masked regions will always be composited.
*/ */
outputOnlyMaskedRegions: z.boolean().default(true), outputOnlyMaskedRegions: z.boolean(),
/** /**
* Whether to automatically process the operations like filtering and auto-masking. * Whether to automatically process the operations like filtering and auto-masking.
*/ */
autoProcess: z.boolean().default(true), autoProcess: z.boolean(),
/** /**
* The snap-to-grid setting for the canvas. * The snap-to-grid setting for the canvas.
*/ */
snapToGrid: z.boolean().default(true), snapToGrid: z.boolean(),
/** /**
* Whether to show progress on the canvas when generating images. * Whether to show progress on the canvas when generating images.
*/ */
showProgressOnCanvas: z.boolean().default(true), showProgressOnCanvas: z.boolean(),
/** /**
* Whether to show the bounding box overlay on the canvas. * Whether to show the bounding box overlay on the canvas.
*/ */
bboxOverlay: z.boolean().default(false), bboxOverlay: z.boolean(),
/** /**
* Whether to preserve the masked region instead of inpainting it. * Whether to preserve the masked region instead of inpainting it.
*/ */
preserveMask: z.boolean().default(false), preserveMask: z.boolean(),
/** /**
* Whether to show only raster layers while staging. * Whether to show only raster layers while staging.
*/ */
isolatedStagingPreview: z.boolean().default(true), isolatedStagingPreview: z.boolean(),
/** /**
* Whether to show only the selected layer while filtering, transforming, or doing other operations. * Whether to show only the selected layer while filtering, transforming, or doing other operations.
*/ */
isolatedLayerPreview: z.boolean().default(true), isolatedLayerPreview: z.boolean(),
/** /**
* Whether to use pressure sensitivity for the brush and eraser tool when a pen device is used. * Whether to use pressure sensitivity for the brush and eraser tool when a pen device is used.
*/ */
pressureSensitivity: z.boolean().default(true), pressureSensitivity: z.boolean(),
/** /**
* Whether to show the rule of thirds composition guide overlay on the canvas. * Whether to show the rule of thirds composition guide overlay on the canvas.
*/ */
ruleOfThirds: z.boolean().default(false), ruleOfThirds: z.boolean(),
/** /**
* Whether to save all staging images to the gallery instead of keeping them as intermediate images. * Whether to save all staging images to the gallery instead of keeping them as intermediate images.
*/ */
saveAllImagesToGallery: z.boolean().default(false), saveAllImagesToGallery: z.boolean(),
/** /**
* The auto-switch mode for the canvas staging area. * The auto-switch mode for the canvas staging area.
*/ */
stagingAreaAutoSwitch: zAutoSwitchMode.default('switch_on_start'), stagingAreaAutoSwitch: zAutoSwitchMode,
}); });
type CanvasSettingsState = z.infer<typeof zCanvasSettingsState>; type CanvasSettingsState = z.infer<typeof zCanvasSettingsState>;
const getInitialState = () => zCanvasSettingsState.parse({}); const getInitialState = (): CanvasSettingsState => ({
showHUD: true,
clipToBbox: false,
dynamicGrid: false,
invertScrollForToolWidth: false,
brushWidth: 50,
eraserWidth: 50,
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
outputOnlyMaskedRegions: true,
autoProcess: true,
snapToGrid: true,
showProgressOnCanvas: true,
bboxOverlay: false,
preserveMask: false,
isolatedStagingPreview: true,
isolatedLayerPreview: true,
pressureSensitivity: true,
ruleOfThirds: false,
saveAllImagesToGallery: false,
stagingAreaAutoSwitch: 'switch_on_start',
});
export const canvasSettingsSlice = createSlice({ const slice = createSlice({
name: 'canvasSettings', name: 'canvasSettings',
initialState: getInitialState(), initialState: getInitialState(),
reducers: { reducers: {
@@ -184,18 +205,15 @@ export const {
settingsRuleOfThirdsToggled, settingsRuleOfThirdsToggled,
settingsSaveAllImagesToGalleryToggled, settingsSaveAllImagesToGalleryToggled,
settingsStagingAreaAutoSwitchChanged, settingsStagingAreaAutoSwitchChanged,
} = canvasSettingsSlice.actions; } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const canvasSettingsSliceConfig: SliceConfig<typeof slice> = {
const migrate = (state: any): any => { slice,
return state; schema: zCanvasSettingsState,
}; getInitialState,
persistConfig: {
export const canvasSettingsPersistConfig: PersistConfig<CanvasSettingsState> = { migrate: (state) => zCanvasSettingsState.parse(state),
name: canvasSettingsSlice.name, },
initialState: getInitialState(),
migrate,
persistDenylist: [],
}; };
export const selectCanvasSettingsSlice = (s: RootState) => s.canvasSettings; export const selectCanvasSettingsSlice = (s: RootState) => s.canvasSettings;

View File

@@ -1,6 +1,6 @@
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit'; import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit'; import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store'; import type { SliceConfig } from 'app/store/types';
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils'; import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple'; import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
@@ -80,6 +80,7 @@ import {
isFLUXReduxConfig, isFLUXReduxConfig,
isImagenAspectRatioID, isImagenAspectRatioID,
isIPAdapterConfig, isIPAdapterConfig,
zCanvasState,
} from './types'; } from './types';
import { import {
converters, converters,
@@ -95,7 +96,7 @@ import {
initialT2IAdapter, initialT2IAdapter,
} from './util'; } from './util';
export const canvasSlice = createSlice({ const slice = createSlice({
name: 'canvas', name: 'canvas',
initialState: getInitialCanvasState(), initialState: getInitialCanvasState(),
reducers: { reducers: {
@@ -1675,19 +1676,7 @@ export const {
inpaintMaskDenoiseLimitChanged, inpaintMaskDenoiseLimitChanged,
inpaintMaskDenoiseLimitDeleted, inpaintMaskDenoiseLimitDeleted,
// inpaintMaskRecalled, // inpaintMaskRecalled,
} = canvasSlice.actions; } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const canvasPersistConfig: PersistConfig<CanvasState> = {
name: canvasSlice.name,
initialState: getInitialCanvasState(),
migrate,
persistDenylist: [],
};
const syncScaledSize = (state: CanvasState) => { const syncScaledSize = (state: CanvasState) => {
if (API_BASE_MODELS.includes(state.bbox.modelBase)) { if (API_BASE_MODELS.includes(state.bbox.modelBase)) {
@@ -1710,14 +1699,14 @@ const syncScaledSize = (state: CanvasState) => {
let filter = true; let filter = true;
export const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = { const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
limit: 64, limit: 64,
undoType: canvasUndo.type, undoType: canvasUndo.type,
redoType: canvasRedo.type, redoType: canvasRedo.type,
clearHistoryType: canvasClearHistory.type, clearHistoryType: canvasClearHistory.type,
filter: (action, _state, _history) => { filter: (action, _state, _history) => {
// Ignore all actions from other slices // Ignore all actions from other slices
if (!action.type.startsWith(canvasSlice.name)) { if (!action.type.startsWith(slice.name)) {
return false; return false;
} }
// Throttle rapid actions of the same type // Throttle rapid actions of the same type
@@ -1728,6 +1717,18 @@ export const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> =
// debug: import.meta.env.MODE === 'development', // debug: import.meta.env.MODE === 'development',
}; };
export const canvasSliceConfig: SliceConfig<typeof slice> = {
slice,
getInitialState: getInitialCanvasState,
schema: zCanvasState,
persistConfig: {
migrate: (state) => zCanvasState.parse(state),
},
undoableConfig: {
reduxUndoOptions: canvasUndoableConfig,
},
};
const doNotGroupMatcher = isAnyOf(entityBrushLineAdded, entityEraserLineAdded, entityRectAdded); const doNotGroupMatcher = isAnyOf(entityBrushLineAdded, entityEraserLineAdded, entityRectAdded);
// Store rapid actions of the same type at most once every x time. // Store rapid actions of the same type at most once every x time.

View File

@@ -1,27 +1,29 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit'; import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants'; import { EMPTY_ARRAY } from 'app/store/constants';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone'; import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util'; import { getPrefixedId } from 'features/controlLayers/konva/util';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import { assert } from 'tsafe';
import z from 'zod';
type CanvasStagingAreaState = { const zCanvasStagingAreaState = z.object({
_version: 1; _version: z.literal(1),
canvasSessionId: string; canvasSessionId: z.string(),
canvasDiscardedQueueItems: number[]; canvasDiscardedQueueItems: z.array(z.number().int()),
}; });
type CanvasStagingAreaState = z.infer<typeof zCanvasStagingAreaState>;
const INITIAL_STATE: CanvasStagingAreaState = { const getInitialState = (): CanvasStagingAreaState => ({
_version: 1, _version: 1,
canvasSessionId: getPrefixedId('canvas'), canvasSessionId: getPrefixedId('canvas'),
canvasDiscardedQueueItems: [], canvasDiscardedQueueItems: [],
}; });
const getInitialState = (): CanvasStagingAreaState => deepClone(INITIAL_STATE); const slice = createSlice({
export const canvasSessionSlice = createSlice({
name: 'canvasSession', name: 'canvasSession',
initialState: getInitialState(), initialState: getInitialState(),
reducers: { reducers: {
@@ -48,26 +50,26 @@ export const canvasSessionSlice = createSlice({
}, },
}); });
export const { canvasSessionReset, canvasQueueItemDiscarded } = canvasSessionSlice.actions; export const { canvasSessionReset, canvasQueueItemDiscarded } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const canvasSessionSliceConfig: SliceConfig<typeof slice> = {
const migrate = (state: any): any => { slice,
if (!('_version' in state)) { schema: zCanvasStagingAreaState,
state._version = 1; getInitialState,
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas'); persistConfig: {
} migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
return state; return zCanvasStagingAreaState.parse(state);
},
},
}; };
export const canvasStagingAreaPersistConfig: PersistConfig<CanvasStagingAreaState> = { export const selectCanvasSessionSlice = (s: RootState) => s[slice.name];
name: canvasSessionSlice.name,
initialState: getInitialState(),
migrate,
persistDenylist: [],
};
export const selectCanvasSessionSlice = (s: RootState) => s[canvasSessionSlice.name];
export const selectCanvasSessionId = createSelector(selectCanvasSessionSlice, ({ canvasSessionId }) => canvasSessionId); export const selectCanvasSessionId = createSelector(selectCanvasSessionSlice, ({ canvasSessionId }) => canvasSessionId);
const selectDiscardedItems = createSelector( const selectDiscardedItems = createSelector(

View File

@@ -166,7 +166,7 @@ const _zFilterConfig = z.discriminatedUnion('type', [
]); ]);
export type FilterConfig = z.infer<typeof _zFilterConfig>; export type FilterConfig = z.infer<typeof _zFilterConfig>;
const zFilterType = z.enum([ export const zFilterType = z.enum([
'adjust_image', 'adjust_image',
'canny_edge_detection', 'canny_edge_detection',
'color_map', 'color_map',

View File

@@ -1,30 +1,32 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit'; import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone'; import type { SliceConfig } from 'app/store/types';
import { paramsReset } from 'features/controlLayers/store/paramsSlice'; import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { LoRA } from 'features/controlLayers/store/types'; import { type LoRA, zLoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import type { LoRAModelConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import z from 'zod';
type LoRAsState = { const zLoRAsState = z.object({
loras: LoRA[]; loras: z.array(zLoRA),
}; });
type LoRAsState = z.infer<typeof zLoRAsState>;
const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = { const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75, weight: 0.75,
isEnabled: true, isEnabled: true,
}; };
const initialState: LoRAsState = { const getInitialState = (): LoRAsState => ({
loras: [], loras: [],
}; });
const selectLoRA = (state: LoRAsState, id: string) => state.loras.find((lora) => lora.id === id); const selectLoRA = (state: LoRAsState, id: string) => state.loras.find((lora) => lora.id === id);
export const lorasSlice = createSlice({ const slice = createSlice({
name: 'loras', name: 'loras',
initialState, initialState: getInitialState(),
reducers: { reducers: {
loraAdded: { loraAdded: {
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => { reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
@@ -66,24 +68,21 @@ export const lorasSlice = createSlice({
extraReducers(builder) { extraReducers(builder) {
builder.addCase(paramsReset, () => { builder.addCase(paramsReset, () => {
// When a new session is requested, clear all LoRAs // When a new session is requested, clear all LoRAs
return deepClone(initialState); return getInitialState();
}); });
}, },
}); });
export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } = export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } =
lorasSlice.actions; slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const lorasSliceConfig: SliceConfig<typeof slice> = {
const migrate = (state: any): any => { slice,
return state; schema: zLoRAsState,
}; getInitialState,
persistConfig: {
export const lorasPersistConfig: PersistConfig<LoRAsState> = { migrate: (state) => zLoRAsState.parse(state),
name: lorasSlice.name, },
initialState,
migrate,
persistDenylist: [],
}; };
export const selectLoRAsSlice = (state: RootState) => state.loras; export const selectLoRAsSlice = (state: RootState) => state.loras;

View File

@@ -1,6 +1,7 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple'; import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
import { clamp } from 'es-toolkit/compat'; import { clamp } from 'es-toolkit/compat';
@@ -15,6 +16,7 @@ import {
isChatGPT4oAspectRatioID, isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID, isFluxKontextAspectRatioID,
isImagenAspectRatioID, isImagenAspectRatioID,
zParamsState,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions'; import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants'; import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
@@ -40,7 +42,7 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models'; import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types';
export const paramsSlice = createSlice({ const slice = createSlice({
name: 'params', name: 'params',
initialState: getInitialParamsState(), initialState: getInitialParamsState(),
reducers: { reducers: {
@@ -92,7 +94,12 @@ export const paramsSlice = createSlice({
state, state,
action: PayloadAction<{ model: ParameterModel | null; previousModel?: ParameterModel | null }> action: PayloadAction<{ model: ParameterModel | null; previousModel?: ParameterModel | null }>
) => { ) => {
const { model, previousModel } = action.payload; const { previousModel } = action.payload;
const result = zParamsState.shape.model.safeParse(action.payload.model);
if (!result.success) {
return;
}
const model = result.data;
state.model = model; state.model = model;
// If the model base changes (e.g. SD1.5 -> SDXL), we need to change a few things // If the model base changes (e.g. SD1.5 -> SDXL), we need to change a few things
@@ -111,25 +118,53 @@ export const paramsSlice = createSlice({
}, },
vaeSelected: (state, action: PayloadAction<ParameterVAEModel | null>) => { vaeSelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
// null is a valid VAE! // null is a valid VAE!
state.vae = action.payload; const result = zParamsState.shape.vae.safeParse(action.payload);
if (!result.success) {
return;
}
state.vae = result.data;
}, },
fluxVAESelected: (state, action: PayloadAction<ParameterVAEModel | null>) => { fluxVAESelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
state.fluxVAE = action.payload; const result = zParamsState.shape.fluxVAE.safeParse(action.payload);
if (!result.success) {
return;
}
state.fluxVAE = result.data;
}, },
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => { t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
state.t5EncoderModel = action.payload; const result = zParamsState.shape.t5EncoderModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.t5EncoderModel = result.data;
}, },
controlLoRAModelSelected: (state, action: PayloadAction<ParameterControlLoRAModel | null>) => { controlLoRAModelSelected: (state, action: PayloadAction<ParameterControlLoRAModel | null>) => {
state.controlLora = action.payload; const result = zParamsState.shape.controlLora.safeParse(action.payload);
if (!result.success) {
return;
}
state.controlLora = result.data;
}, },
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => { clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
state.clipEmbedModel = action.payload; const result = zParamsState.shape.clipEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipEmbedModel = result.data;
}, },
clipLEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPLEmbedModel | null>) => { clipLEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPLEmbedModel | null>) => {
state.clipLEmbedModel = action.payload; const result = zParamsState.shape.clipLEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipLEmbedModel = result.data;
}, },
clipGEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPGEmbedModel | null>) => { clipGEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPGEmbedModel | null>) => {
state.clipGEmbedModel = action.payload; const result = zParamsState.shape.clipGEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipGEmbedModel = result.data;
}, },
vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => { vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => {
state.vaePrecision = action.payload; state.vaePrecision = action.payload;
@@ -156,7 +191,11 @@ export const paramsSlice = createSlice({
state.shouldConcatPrompts = action.payload; state.shouldConcatPrompts = action.payload;
}, },
refinerModelChanged: (state, action: PayloadAction<ParameterSDXLRefinerModel | null>) => { refinerModelChanged: (state, action: PayloadAction<ParameterSDXLRefinerModel | null>) => {
state.refinerModel = action.payload; const result = zParamsState.shape.refinerModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.refinerModel = result.data;
}, },
setRefinerSteps: (state, action: PayloadAction<number>) => { setRefinerSteps: (state, action: PayloadAction<number>) => {
state.refinerSteps = action.payload; state.refinerSteps = action.payload;
@@ -397,18 +436,15 @@ export const {
syncedToOptimalDimension, syncedToOptimalDimension,
paramsReset, paramsReset,
} = paramsSlice.actions; } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const paramsSliceConfig: SliceConfig<typeof slice> = {
const migrate = (state: any): any => { slice,
return state; schema: zParamsState,
}; getInitialState: getInitialParamsState,
persistConfig: {
export const paramsPersistConfig: PersistConfig<ParamsState> = { migrate: (state) => zParamsState.parse(state),
name: paramsSlice.name, },
initialState: getInitialParamsState(),
migrate,
persistDenylist: [],
}; };
export const selectParamsSlice = (state: RootState) => state.params; export const selectParamsSlice = (state: RootState) => state.params;

View File

@@ -2,7 +2,8 @@ import { objectEquals } from '@observ33r/object-equals';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { clamp } from 'es-toolkit/compat'; import { clamp } from 'es-toolkit/compat';
import { getPrefixedId } from 'features/controlLayers/konva/util'; import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types'; import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types';
@@ -18,7 +19,7 @@ import { assert } from 'tsafe';
import type { PartialDeep } from 'type-fest'; import type { PartialDeep } from 'type-fest';
import type { CLIPVisionModelV2, IPMethodV2, RefImageState } from './types'; import type { CLIPVisionModelV2, IPMethodV2, RefImageState } from './types';
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig } from './types'; import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig, zRefImagesState } from './types';
import { import {
getReferenceImageState, getReferenceImageState,
imageDTOToImageWithDims, imageDTOToImageWithDims,
@@ -36,7 +37,7 @@ type PayloadActionWithId<T = void> = T extends void
} & T } & T
>; >;
export const refImagesSlice = createSlice({ const slice = createSlice({
name: 'refImages', name: 'refImages',
initialState: getInitialRefImagesState(), initialState: getInitialRefImagesState(),
reducers: { reducers: {
@@ -263,18 +264,16 @@ export const {
refImageFLUXReduxImageInfluenceChanged, refImageFLUXReduxImageInfluenceChanged,
refImageIsEnabledToggled, refImageIsEnabledToggled,
refImagesRecalled, refImagesRecalled,
} = refImagesSlice.actions; } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const refImagesSliceConfig: SliceConfig<typeof slice> = {
const migrate = (state: any): any => { slice,
return state; schema: zRefImagesState,
}; getInitialState: getInitialRefImagesState,
persistConfig: {
export const refImagesPersistConfig: PersistConfig<RefImagesState> = { migrate: (state) => zRefImagesState.parse(state),
name: refImagesSlice.name, persistDenylist: ['selectedEntityId', 'isPanelOpen'],
initialState: getInitialRefImagesState(), },
migrate,
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
}; };
export const selectRefImagesSlice = (state: RootState) => state.refImages; export const selectRefImagesSlice = (state: RootState) => state.refImages;

View File

@@ -1,9 +1,7 @@
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types'; import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
import type { ProgressImage } from 'features/nodes/types/common'; import type { ProgressImage } from 'features/nodes/types/common';
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common'; import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { import {
zParameterCanvasCoherenceMode, zParameterCanvasCoherenceMode,
zParameterCFGRescaleMultiplier, zParameterCFGRescaleMultiplier,
@@ -29,33 +27,17 @@ import {
zParameterT5EncoderModel, zParameterT5EncoderModel,
zParameterVAEModel, zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import { getImageDTOSafe } from 'services/api/endpoints/images';
import type { JsonObject } from 'type-fest'; import type { JsonObject } from 'type-fest';
import { z } from 'zod'; import { z } from 'zod';
const zId = z.string().min(1); const zId = z.string().min(1);
const zName = z.string().min(1).nullable(); const zName = z.string().min(1).nullable();
const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => { export const zImageWithDims = z.object({
try { image_name: z.string(),
await fetchModelConfigByIdentifier(modelIdentifier); width: z.number().int().positive(),
return true; height: z.number().int().positive(),
} catch {
return false;
}
}); });
const zImageWithDims = z
.object({
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
})
.refine(async (v) => {
const { image_name } = v;
const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true });
return imageDTO !== null;
});
export type ImageWithDims = z.infer<typeof zImageWithDims>; export type ImageWithDims = z.infer<typeof zImageWithDims>;
const zImageWithDimsDataURL = z.object({ const zImageWithDimsDataURL = z.object({
@@ -253,7 +235,7 @@ export type CanvasObjectState = z.infer<typeof zCanvasObjectState>;
const zIPAdapterConfig = z.object({ const zIPAdapterConfig = z.object({
type: z.literal('ip_adapter'), type: z.literal('ip_adapter'),
image: zImageWithDims.nullable(), image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(), model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2), weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct, beginEndStepPct: zBeginEndStepPct,
method: zIPMethodV2, method: zIPMethodV2,
@@ -268,7 +250,7 @@ export type FLUXReduxImageInfluence = z.infer<typeof zFLUXReduxImageInfluence>;
const zFLUXReduxConfig = z.object({ const zFLUXReduxConfig = z.object({
type: z.literal('flux_redux'), type: z.literal('flux_redux'),
image: zImageWithDims.nullable(), image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(), model: zModelIdentifierField.nullable(),
imageInfluence: zFLUXReduxImageInfluence.default('highest'), imageInfluence: zFLUXReduxImageInfluence.default('highest'),
}); });
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>; export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
@@ -281,14 +263,14 @@ const zChatGPT4oReferenceImageConfig = z.object({
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else * But we use a model drop down to switch between different ref image types, so there needs to be a model here else
* there will be no way to switch between ref image types. * there will be no way to switch between ref image types.
*/ */
model: zServerValidatedModelIdentifierField.nullable(), model: zModelIdentifierField.nullable(),
}); });
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>; export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
const zFluxKontextReferenceImageConfig = z.object({ const zFluxKontextReferenceImageConfig = z.object({
type: z.literal('flux_kontext_reference_image'), type: z.literal('flux_kontext_reference_image'),
image: zImageWithDims.nullable(), image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(), model: zModelIdentifierField.nullable(),
}); });
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>; export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;
@@ -360,7 +342,7 @@ export type CanvasInpaintMaskState = z.infer<typeof zCanvasInpaintMaskState>;
const zControlNetConfig = z.object({ const zControlNetConfig = z.object({
type: z.literal('controlnet'), type: z.literal('controlnet'),
model: zServerValidatedModelIdentifierField.nullable(), model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2), weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct, beginEndStepPct: zBeginEndStepPct,
controlMode: zControlModeV2, controlMode: zControlModeV2,
@@ -369,7 +351,7 @@ export type ControlNetConfig = z.infer<typeof zControlNetConfig>;
const zT2IAdapterConfig = z.object({ const zT2IAdapterConfig = z.object({
type: z.literal('t2i_adapter'), type: z.literal('t2i_adapter'),
model: zServerValidatedModelIdentifierField.nullable(), model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2), weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct, beginEndStepPct: zBeginEndStepPct,
}); });
@@ -378,7 +360,7 @@ export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
const zControlLoRAConfig = z.object({ const zControlLoRAConfig = z.object({
type: z.literal('control_lora'), type: z.literal('control_lora'),
weight: z.number().gte(-1).lte(2), weight: z.number().gte(-1).lte(2),
model: zServerValidatedModelIdentifierField.nullable(), model: zModelIdentifierField.nullable(),
}); });
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>; export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
@@ -424,12 +406,13 @@ export const zCanvasEntityIdentifer = z.object({
}); });
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T }; export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
export type LoRA = { export const zLoRA = z.object({
id: string; id: z.string(),
isEnabled: boolean; isEnabled: z.boolean(),
model: ParameterLoRAModel; model: zModelIdentifierField,
weight: number; weight: z.number().gte(-1).lte(2),
}; });
export type LoRA = z.infer<typeof zLoRA>;
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage }; export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
@@ -522,62 +505,108 @@ const zDimensionsState = z.object({
aspectRatio: zAspectRatioConfig, aspectRatio: zAspectRatioConfig,
}); });
const zParamsState = z.object({ export const zParamsState = z.object({
maskBlur: z.number().default(16), maskBlur: z.number(),
maskBlurMethod: zParameterMaskBlurMethod.default('box'), maskBlurMethod: zParameterMaskBlurMethod,
canvasCoherenceMode: zParameterCanvasCoherenceMode.default('Gaussian Blur'), canvasCoherenceMode: zParameterCanvasCoherenceMode,
canvasCoherenceMinDenoise: zParameterStrength.default(0), canvasCoherenceMinDenoise: zParameterStrength,
canvasCoherenceEdgeSize: z.number().default(16), canvasCoherenceEdgeSize: z.number(),
infillMethod: z.string().default('lama'), infillMethod: z.string(),
infillTileSize: z.number().default(32), infillTileSize: z.number(),
infillPatchmatchDownscaleSize: z.number().default(1), infillPatchmatchDownscaleSize: z.number(),
infillColorValue: zRgbaColor.default({ r: 0, g: 0, b: 0, a: 1 }), infillColorValue: zRgbaColor,
cfgScale: zParameterCFGScale.default(7.5), cfgScale: zParameterCFGScale,
cfgRescaleMultiplier: zParameterCFGRescaleMultiplier.default(0), cfgRescaleMultiplier: zParameterCFGRescaleMultiplier,
guidance: zParameterGuidance.default(4), guidance: zParameterGuidance,
img2imgStrength: zParameterStrength.default(0.75), img2imgStrength: zParameterStrength,
optimizedDenoisingEnabled: z.boolean().default(true), optimizedDenoisingEnabled: z.boolean(),
iterations: z.number().default(1), iterations: z.number(),
scheduler: zParameterScheduler.default('dpmpp_3m_k'), scheduler: zParameterScheduler,
upscaleScheduler: zParameterScheduler.default('kdpm_2'), upscaleScheduler: zParameterScheduler,
upscaleCfgScale: zParameterCFGScale.default(2), upscaleCfgScale: zParameterCFGScale,
seed: zParameterSeed.default(0), seed: zParameterSeed,
shouldRandomizeSeed: z.boolean().default(true), shouldRandomizeSeed: z.boolean(),
steps: zParameterSteps.default(30), steps: zParameterSteps,
model: zParameterModel.nullable().default(null), model: zParameterModel.nullable(),
vae: zParameterVAEModel.nullable().default(null), vae: zParameterVAEModel.nullable(),
vaePrecision: zParameterPrecision.default('fp32'), vaePrecision: zParameterPrecision,
fluxVAE: zParameterVAEModel.nullable().default(null), fluxVAE: zParameterVAEModel.nullable(),
seamlessXAxis: z.boolean().default(false), seamlessXAxis: z.boolean(),
seamlessYAxis: z.boolean().default(false), seamlessYAxis: z.boolean(),
clipSkip: z.number().default(0), clipSkip: z.number(),
shouldUseCpuNoise: z.boolean().default(true), shouldUseCpuNoise: z.boolean(),
positivePrompt: zParameterPositivePrompt.default(''), positivePrompt: zParameterPositivePrompt,
// Negative prompt may be disabled, in which case it will be null negativePrompt: zParameterNegativePrompt,
negativePrompt: zParameterNegativePrompt.default(null), positivePrompt2: zParameterPositiveStylePromptSDXL,
positivePrompt2: zParameterPositiveStylePromptSDXL.default(''), negativePrompt2: zParameterNegativeStylePromptSDXL,
negativePrompt2: zParameterNegativeStylePromptSDXL.default(''), shouldConcatPrompts: z.boolean(),
shouldConcatPrompts: z.boolean().default(true), refinerModel: zParameterSDXLRefinerModel.nullable(),
refinerModel: zParameterSDXLRefinerModel.nullable().default(null), refinerSteps: z.number(),
refinerSteps: z.number().default(20), refinerCFGScale: z.number(),
refinerCFGScale: z.number().default(7.5), refinerScheduler: zParameterScheduler,
refinerScheduler: zParameterScheduler.default('euler'), refinerPositiveAestheticScore: z.number(),
refinerPositiveAestheticScore: z.number().default(6), refinerNegativeAestheticScore: z.number(),
refinerNegativeAestheticScore: z.number().default(2.5), refinerStart: z.number(),
refinerStart: z.number().default(0.8), t5EncoderModel: zParameterT5EncoderModel.nullable(),
t5EncoderModel: zParameterT5EncoderModel.nullable().default(null), clipEmbedModel: zParameterCLIPEmbedModel.nullable(),
clipEmbedModel: zParameterCLIPEmbedModel.nullable().default(null), clipLEmbedModel: zParameterCLIPLEmbedModel.nullable(),
clipLEmbedModel: zParameterCLIPLEmbedModel.nullable().default(null), clipGEmbedModel: zParameterCLIPGEmbedModel.nullable(),
clipGEmbedModel: zParameterCLIPGEmbedModel.nullable().default(null), controlLora: zParameterControlLoRAModel.nullable(),
controlLora: zParameterControlLoRAModel.nullable().default(null), dimensions: zDimensionsState,
dimensions: zDimensionsState.default({
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
}),
}); });
export type ParamsState = z.infer<typeof zParamsState>; export type ParamsState = z.infer<typeof zParamsState>;
const INITIAL_PARAMS_STATE = zParamsState.parse({}); export const getInitialParamsState = (): ParamsState => ({
export const getInitialParamsState = () => deepClone(INITIAL_PARAMS_STATE); maskBlur: 16,
maskBlurMethod: 'box',
canvasCoherenceMode: 'Gaussian Blur',
canvasCoherenceMinDenoise: 0,
canvasCoherenceEdgeSize: 16,
infillMethod: 'lama',
infillTileSize: 32,
infillPatchmatchDownscaleSize: 1,
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
cfgScale: 7.5,
cfgRescaleMultiplier: 0,
guidance: 4,
img2imgStrength: 0.75,
optimizedDenoisingEnabled: true,
iterations: 1,
scheduler: 'dpmpp_3m_k',
upscaleScheduler: 'kdpm_2',
upscaleCfgScale: 2,
seed: 0,
shouldRandomizeSeed: true,
steps: 30,
model: null,
vae: null,
vaePrecision: 'fp32',
fluxVAE: null,
seamlessXAxis: false,
seamlessYAxis: false,
clipSkip: 0,
shouldUseCpuNoise: true,
positivePrompt: '',
negativePrompt: null,
positivePrompt2: '',
negativePrompt2: '',
shouldConcatPrompts: true,
refinerModel: null,
refinerSteps: 20,
refinerCFGScale: 7.5,
refinerScheduler: 'euler',
refinerPositiveAestheticScore: 6,
refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8,
t5EncoderModel: null,
clipEmbedModel: null,
clipLEmbedModel: null,
clipGEmbedModel: null,
controlLora: null,
dimensions: {
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
},
});
const zInpaintMasks = z.object({ const zInpaintMasks = z.object({
isHidden: z.boolean(), isHidden: z.boolean(),
@@ -595,38 +624,45 @@ const zRegionalGuidance = z.object({
isHidden: z.boolean(), isHidden: z.boolean(),
entities: z.array(zCanvasRegionalGuidanceState), entities: z.array(zCanvasRegionalGuidanceState),
}); });
const zCanvasState = z.object({ export const zCanvasState = z.object({
_version: z.literal(3).default(3), _version: z.literal(3),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null), selectedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null), bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
inpaintMasks: zInpaintMasks.default({ isHidden: false, entities: [] }), inpaintMasks: zInpaintMasks,
rasterLayers: zRasterLayers.default({ isHidden: false, entities: [] }), rasterLayers: zRasterLayers,
controlLayers: zControlLayers.default({ isHidden: false, entities: [] }), controlLayers: zControlLayers,
regionalGuidance: zRegionalGuidance.default({ isHidden: false, entities: [] }), regionalGuidance: zRegionalGuidance,
bbox: zBboxState.default({ bbox: zBboxState,
});
export type CanvasState = z.infer<typeof zCanvasState>;
export const getInitialCanvasState = (): CanvasState => ({
_version: 3,
selectedEntityIdentifier: null,
bookmarkedEntityIdentifier: null,
inpaintMasks: { isHidden: false, entities: [] },
rasterLayers: { isHidden: false, entities: [] },
controlLayers: { isHidden: false, entities: [] },
regionalGuidance: { isHidden: false, entities: [] },
bbox: {
rect: { x: 0, y: 0, width: 512, height: 512 }, rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG, aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
scaleMethod: 'auto', scaleMethod: 'auto',
scaledSize: { width: 512, height: 512 }, scaledSize: { width: 512, height: 512 },
modelBase: 'sd-1', modelBase: 'sd-1',
}), },
}); });
export type CanvasState = z.infer<typeof zCanvasState>;
const zRefImagesState = z.object({ export const zRefImagesState = z.object({
selectedEntityId: z.string().nullable().default(null), selectedEntityId: z.string().nullable(),
isPanelOpen: z.boolean().default(false), isPanelOpen: z.boolean(),
entities: z.array(zRefImageState).default(() => []), entities: z.array(zRefImageState),
}); });
export type RefImagesState = z.infer<typeof zRefImagesState>; export type RefImagesState = z.infer<typeof zRefImagesState>;
const INITIAL_REF_IMAGES_STATE = zRefImagesState.parse({}); export const getInitialRefImagesState = (): RefImagesState => ({
export const getInitialRefImagesState = () => deepClone(INITIAL_REF_IMAGES_STATE); selectedEntityId: null,
isPanelOpen: false,
/** entities: [],
* Gets a fresh canvas initial state with no references in memory to existing objects. });
*/
const CANVAS_INITIAL_STATE = zCanvasState.parse({});
export const getInitialCanvasState = () => deepClone(CANVAS_INITIAL_STATE);
export const zCanvasReferenceImageState_OLD = zCanvasEntityBase.extend({ export const zCanvasReferenceImageState_OLD = zCanvasEntityBase.extend({
type: z.literal('reference_image'), type: z.literal('reference_image'),

View File

@@ -1,25 +1,29 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { buildZodTypeGuard } from 'common/util/zodUtils'; import { buildZodTypeGuard } from 'common/util/zodUtils';
import { isPlainObject } from 'es-toolkit';
import { assert } from 'tsafe';
import { z } from 'zod'; import { z } from 'zod';
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']); const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
export const isSeedBehaviour = buildZodTypeGuard(zSeedBehaviour); export const isSeedBehaviour = buildZodTypeGuard(zSeedBehaviour);
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>; export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
export interface DynamicPromptsState { const zDynamicPromptsState = z.object({
_version: 1; _version: z.literal(1),
maxPrompts: number; maxPrompts: z.number().int().min(1).max(1000),
combinatorial: boolean; combinatorial: z.boolean(),
prompts: string[]; prompts: z.array(z.string()),
parsingError: string | undefined | null; parsingError: z.string().nullish(),
isError: boolean; isError: z.boolean(),
isLoading: boolean; isLoading: z.boolean(),
seedBehaviour: SeedBehaviour; seedBehaviour: zSeedBehaviour,
} });
export type DynamicPromptsState = z.infer<typeof zDynamicPromptsState>;
const initialDynamicPromptsState: DynamicPromptsState = { const getInitialState = (): DynamicPromptsState => ({
_version: 1, _version: 1,
maxPrompts: 100, maxPrompts: 100,
combinatorial: true, combinatorial: true,
@@ -28,11 +32,11 @@ const initialDynamicPromptsState: DynamicPromptsState = {
isError: false, isError: false,
isLoading: false, isLoading: false,
seedBehaviour: 'PER_ITERATION', seedBehaviour: 'PER_ITERATION',
}; });
export const dynamicPromptsSlice = createSlice({ const slice = createSlice({
name: 'dynamicPrompts', name: 'dynamicPrompts',
initialState: initialDynamicPromptsState, initialState: getInitialState(),
reducers: { reducers: {
maxPromptsChanged: (state, action: PayloadAction<number>) => { maxPromptsChanged: (state, action: PayloadAction<number>) => {
state.maxPrompts = action.payload; state.maxPrompts = action.payload;
@@ -63,21 +67,22 @@ export const {
isErrorChanged, isErrorChanged,
isLoadingChanged, isLoadingChanged,
seedBehaviourChanged, seedBehaviourChanged,
} = dynamicPromptsSlice.actions; } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const dynamicPromptsSliceConfig: SliceConfig<typeof slice> = {
const migrateDynamicPromptsState = (state: any): any => { slice,
if (!('_version' in state)) { schema: zDynamicPromptsState,
state._version = 1; getInitialState,
} persistConfig: {
return state; migrate: (state) => {
}; assert(isPlainObject(state));
if (!('_version' in state)) {
export const dynamicPromptsPersistConfig: PersistConfig<DynamicPromptsState> = { state._version = 1;
name: dynamicPromptsSlice.name, }
initialState: initialDynamicPromptsState, return zDynamicPromptsState.parse(state);
migrate: migrateDynamicPromptsState, },
persistDenylist: ['prompts'], persistDenylist: ['prompts', 'parsingError', 'isError', 'isLoading'],
},
}; };
export const selectDynamicPromptsSlice = (state: RootState) => state.dynamicPrompts; export const selectDynamicPromptsSlice = (state: RootState) => state.dynamicPrompts;

View File

@@ -1,5 +1,6 @@
import { MenuItem } from '@invoke-ai/ui-library'; import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext'; import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice'; import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
@@ -14,7 +15,7 @@ export const ImageMenuItemSendToUpscale = memo(() => {
const imageDTO = useImageDTOContext(); const imageDTO = useImageDTOContext();
const handleSendToCanvas = useCallback(() => { const handleSendToCanvas = useCallback(() => {
dispatch(upscaleInitialImageChanged(imageDTO)); dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
navigationApi.switchToTab('upscaling'); navigationApi.switchToTab('upscaling');
toast({ toast({
id: 'SENT_TO_CANVAS', id: 'SENT_TO_CANVAS',

View File

@@ -1,13 +1,23 @@
import { objectEquals } from '@observ33r/object-equals'; import { objectEquals } from '@observ33r/object-equals';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { uniq } from 'es-toolkit/compat'; import { uniq } from 'es-toolkit/compat';
import type { BoardRecordOrderBy } from 'services/api/types'; import type { BoardRecordOrderBy } from 'services/api/types';
import { assert } from 'tsafe';
import type { BoardId, ComparisonMode, GalleryState, GalleryView, OrderDir } from './types'; import {
type BoardId,
type ComparisonMode,
type GalleryState,
type GalleryView,
type OrderDir,
zGalleryState,
} from './types';
const initialGalleryState: GalleryState = { const getInitialState = (): GalleryState => ({
selection: [], selection: [],
shouldAutoSwitch: true, shouldAutoSwitch: true,
autoAssignBoardOnClick: true, autoAssignBoardOnClick: true,
@@ -26,11 +36,11 @@ const initialGalleryState: GalleryState = {
shouldShowArchivedBoards: false, shouldShowArchivedBoards: false,
boardsListOrderBy: 'created_at', boardsListOrderBy: 'created_at',
boardsListOrderDir: 'DESC', boardsListOrderDir: 'DESC',
}; });
export const gallerySlice = createSlice({ const slice = createSlice({
name: 'gallery', name: 'gallery',
initialState: initialGalleryState, initialState: getInitialState(),
reducers: { reducers: {
imageSelected: (state, action: PayloadAction<string | null>) => { imageSelected: (state, action: PayloadAction<string | null>) => {
// Let's be efficient here and not update the selection unless it has actually changed. This helps to prevent // Let's be efficient here and not update the selection unless it has actually changed. This helps to prevent
@@ -187,21 +197,22 @@ export const {
searchTermChanged, searchTermChanged,
boardsListOrderByChanged, boardsListOrderByChanged,
boardsListOrderDirChanged, boardsListOrderDirChanged,
} = gallerySlice.actions; } = slice.actions;
export const selectGallerySlice = (state: RootState) => state.gallery; export const selectGallerySlice = (state: RootState) => state.gallery;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const gallerySliceConfig: SliceConfig<typeof slice> = {
const migrateGalleryState = (state: any): any => { slice,
if (!('_version' in state)) { schema: zGalleryState,
state._version = 1; getInitialState,
} persistConfig: {
return state; migrate: (state) => {
}; assert(isPlainObject(state));
if (!('_version' in state)) {
export const galleryPersistConfig: PersistConfig<GalleryState> = { state._version = 1;
name: gallerySlice.name, }
initialState: initialGalleryState, return zGalleryState.parse(state);
migrate: migrateGalleryState, },
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'], persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'],
},
}; };

View File

@@ -0,0 +1,13 @@
import type { S } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import type { BoardRecordOrderBy } from './types';
describe('Gallery Types', () => {
// Ensure zod types match OpenAPI types
test('BoardRecordOrderBy', () => {
assert<Equals<BoardRecordOrderBy, S['BoardRecordOrderBy']>>();
});
});

View File

@@ -1,31 +1,41 @@
import type { BoardRecordOrderBy, ImageCategory } from 'services/api/types'; import type { ImageCategory } from 'services/api/types';
import z from 'zod';
const zGalleryView = z.enum(['images', 'assets']);
export type GalleryView = z.infer<typeof zGalleryView>;
const zBoardId = z.union([z.literal('none'), z.intersection(z.string(), z.record(z.never(), z.never()))]);
export type BoardId = z.infer<typeof zBoardId>;
const zComparisonMode = z.enum(['slider', 'side-by-side', 'hover']);
export type ComparisonMode = z.infer<typeof zComparisonMode>;
const zComparisonFit = z.enum(['contain', 'fill']);
export type ComparisonFit = z.infer<typeof zComparisonFit>;
const zOrderDir = z.enum(['ASC', 'DESC']);
export type OrderDir = z.infer<typeof zOrderDir>;
const zBoardRecordOrderBy = z.enum(['created_at', 'board_name']);
export type BoardRecordOrderBy = z.infer<typeof zBoardRecordOrderBy>;
export const IMAGE_CATEGORIES: ImageCategory[] = ['general']; export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other']; export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
export type GalleryView = 'images' | 'assets'; export const zGalleryState = z.object({
export type BoardId = 'none' | (string & Record<never, never>); selection: z.array(z.string()),
export type ComparisonMode = 'slider' | 'side-by-side' | 'hover'; shouldAutoSwitch: z.boolean(),
export type ComparisonFit = 'contain' | 'fill'; autoAssignBoardOnClick: z.boolean(),
export type OrderDir = 'ASC' | 'DESC'; autoAddBoardId: zBoardId,
galleryImageMinimumWidth: z.number(),
selectedBoardId: zBoardId,
galleryView: zGalleryView,
boardSearchText: z.string(),
starredFirst: z.boolean(),
orderDir: zOrderDir,
searchTerm: z.string(),
alwaysShowImageSizeBadge: z.boolean(),
imageToCompare: z.string().nullable(),
comparisonMode: zComparisonMode,
comparisonFit: zComparisonFit,
shouldShowArchivedBoards: z.boolean(),
boardsListOrderBy: zBoardRecordOrderBy,
boardsListOrderDir: zOrderDir,
});
export type GalleryState = { export type GalleryState = z.infer<typeof zGalleryState>;
selection: string[];
shouldAutoSwitch: boolean;
autoAssignBoardOnClick: boolean;
autoAddBoardId: BoardId;
galleryImageMinimumWidth: number;
selectedBoardId: BoardId;
galleryView: GalleryView;
boardSearchText: string;
starredFirst: boolean;
orderDir: OrderDir;
searchTerm: string;
alwaysShowImageSizeBadge: boolean;
imageToCompare: string | null;
comparisonMode: ComparisonMode;
comparisonFit: ComparisonFit;
shouldShowArchivedBoards: boolean;
boardsListOrderBy: BoardRecordOrderBy;
boardsListOrderDir: OrderDir;
};

View File

@@ -58,7 +58,7 @@ export const setRegionalGuidanceReferenceImage = (arg: {
export const setUpscaleInitialImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => { export const setUpscaleInitialImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
const { imageDTO, dispatch } = arg; const { imageDTO, dispatch } = arg;
dispatch(upscaleInitialImageChanged(imageDTO)); dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
}; };
export const setNodeImageFieldImage = (arg: { export const setNodeImageFieldImage = (arg: {

View File

@@ -89,6 +89,7 @@ import { t } from 'i18next';
import type { ComponentType } from 'react'; import type { ComponentType } from 'react';
import { useCallback, useEffect, useState } from 'react'; import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, ModelType } from 'services/api/types'; import type { AnyModelConfig, ModelType } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@@ -787,11 +788,55 @@ const LoRAs: CollectionMetadataHandler<LoRA[]> = {
const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = { const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
[SingleMetadataKey]: true, [SingleMetadataKey]: true,
type: 'CanvasLayers', type: 'CanvasLayers',
parse: async (metadata) => { parse: async (metadata, store) => {
const raw = getProperty(metadata, 'canvas_v2_metadata'); const raw = getProperty(metadata, 'canvas_v2_metadata');
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in // This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
// the zImageWithDims schema. // the zImageWithDims schema.
const parsed = await zCanvasMetadata.parseAsync(raw); const parsed = await zCanvasMetadata.parseAsync(raw);
for (const entity of parsed.controlLayers) {
if (entity.controlAdapter.model) {
await throwIfModelDoesNotExist(entity.controlAdapter.model.key, store);
}
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
}
for (const entity of parsed.inpaintMasks) {
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
}
for (const entity of parsed.rasterLayers) {
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
}
for (const entity of parsed.regionalGuidance) {
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
await throwIfImageDoesNotExist(object.image.image_name, store);
}
}
for (const refImage of entity.referenceImages) {
if (refImage.config.image) {
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
}
if (refImage.config.model) {
await throwIfModelDoesNotExist(refImage.config.model.key, store);
}
}
}
return Promise.resolve(parsed); return Promise.resolve(parsed);
}, },
recall: (value, store) => { recall: (value, store) => {
@@ -824,27 +869,39 @@ const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
const RefImages: CollectionMetadataHandler<RefImageState[]> = { const RefImages: CollectionMetadataHandler<RefImageState[]> = {
[CollectionMetadataKey]: true, [CollectionMetadataKey]: true,
type: 'RefImages', type: 'RefImages',
parse: async (metadata) => { parse: async (metadata, store) => {
let parsed: RefImageState[] | null = null;
try { try {
// First attempt to parse from the v6 slot // First attempt to parse from the v6 slot
const raw = getProperty(metadata, 'ref_images'); const raw = getProperty(metadata, 'ref_images');
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in parsed = z.array(zRefImageState).parse(raw);
// the zImageWithDims schema.
const parsed = await z.array(zRefImageState).parseAsync(raw);
return Promise.resolve(parsed);
} catch { } catch {
// Fall back to extracting from canvas metadata] // Fall back to extracting from canvas metadata]
const raw = getProperty(metadata, 'canvas_v2_metadata.referenceImages.entities'); const raw = getProperty(metadata, 'canvas_v2_metadata.referenceImages.entities');
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in // This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
// the zImageWithDims schema. // the zImageWithDims schema.
const oldParsed = await z.array(zCanvasReferenceImageState_OLD).parseAsync(raw); const oldParsed = await z.array(zCanvasReferenceImageState_OLD).parseAsync(raw);
const parsed: RefImageState[] = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({ parsed = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({
id, id,
config: ipAdapter, config: ipAdapter,
isEnabled, isEnabled,
})); }));
return parsed;
} }
if (!parsed) {
throw new Error('No valid reference images found in metadata');
}
for (const refImage of parsed) {
if (refImage.config.image) {
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
}
if (refImage.config.model) {
await throwIfModelDoesNotExist(refImage.config.model.key, store);
}
}
return parsed;
}, },
recall: (value, store) => { recall: (value, store) => {
const entities = value.map((data) => ({ ...data, id: getPrefixedId('reference_image') })); const entities = value.map((data) => ({ ...data, id: getPrefixedId('reference_image') }));
@@ -1241,3 +1298,19 @@ const isCompatibleWithMainModel = (candidate: ModelIdentifierField, store: AppSt
} }
return candidate.base === base; return candidate.base === base;
}; };
const throwIfImageDoesNotExist = async (name: string, store: AppStore): Promise<void> => {
try {
await store.dispatch(imagesApi.endpoints.getImageDTO.initiate(name, { subscribe: false })).unwrap();
} catch {
throw new Error(`Image with name ${name} does not exist`);
}
};
const throwIfModelDoesNotExist = async (key: string, store: AppStore): Promise<void> => {
try {
await store.dispatch(modelsApi.endpoints.getModelConfig.initiate(key, { subscribe: false }));
} catch {
throw new Error(`Model with key ${key} does not exist`);
}
};

View File

@@ -1,7 +1,6 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types'; import type { AnyModelConfig } from 'services/api/types';
/** /**
* Raised when a model config is unable to be fetched. * Raised when a model config is unable to be fetched.
@@ -47,45 +46,6 @@ const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
} }
}; };
/**
* Fetches the model config for a given model name, base model, and model type. This provides backwards compatibility
* for MM1 model identifiers.
* @param name The model name.
* @param base The base model.
* @param type The model type.
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type: ModelType): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(
modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }, { subscribe: false })
);
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
}
};
/**
* Fetches the model config given an identifier. First attempts to fetch by key, then falls back to fetching by attrs.
* @param identifier The model identifier.
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise<AnyModelConfig> => {
try {
return await fetchModelConfig(identifier.key);
} catch {
try {
return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type);
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`);
}
}
};
/** /**
* Fetches the model config for a given model key and type, and ensures that the model config is of a specific type. * Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.
* @param key The model key. * @param key The model key.

View File

@@ -1,21 +1,28 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { ModelType } from 'services/api/types'; import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { zModelType } from 'features/nodes/types/common';
import { assert } from 'tsafe';
import z from 'zod';
export type FilterableModelType = Exclude<ModelType, 'onnx'> | 'refiner'; const zFilterableModelType = zModelType.exclude(['onnx']).or(z.literal('refiner'));
export type FilterableModelType = z.infer<typeof zFilterableModelType>;
type ModelManagerState = { const zModelManagerState = z.object({
_version: 1; _version: z.literal(1),
selectedModelKey: string | null; selectedModelKey: z.string().nullable(),
selectedModelMode: 'edit' | 'view'; selectedModelMode: z.enum(['edit', 'view']),
searchTerm: string; searchTerm: z.string(),
filteredModelType: FilterableModelType | null; filteredModelType: zFilterableModelType.nullable(),
scanPath: string | undefined; scanPath: z.string().optional(),
shouldInstallInPlace: boolean; shouldInstallInPlace: z.boolean(),
}; });
const initialModelManagerState: ModelManagerState = { type ModelManagerState = z.infer<typeof zModelManagerState>;
const getInitialState = (): ModelManagerState => ({
_version: 1, _version: 1,
selectedModelKey: null, selectedModelKey: null,
selectedModelMode: 'view', selectedModelMode: 'view',
@@ -23,11 +30,11 @@ const initialModelManagerState: ModelManagerState = {
searchTerm: '', searchTerm: '',
scanPath: undefined, scanPath: undefined,
shouldInstallInPlace: true, shouldInstallInPlace: true,
}; });
export const modelManagerV2Slice = createSlice({ const slice = createSlice({
name: 'modelmanagerV2', name: 'modelmanagerV2',
initialState: initialModelManagerState, initialState: getInitialState(),
reducers: { reducers: {
setSelectedModelKey: (state, action: PayloadAction<string | null>) => { setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
state.selectedModelMode = 'view'; state.selectedModelMode = 'view';
@@ -58,21 +65,22 @@ export const {
setSelectedModelMode, setSelectedModelMode,
setScanPath, setScanPath,
shouldInstallInPlaceChanged, shouldInstallInPlaceChanged,
} = modelManagerV2Slice.actions; } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const modelManagerSliceConfig: SliceConfig<typeof slice> = {
const migrateModelManagerState = (state: any): any => { slice,
if (!('_version' in state)) { schema: zModelManagerState,
state._version = 1; getInitialState,
} persistConfig: {
return state; migrate: (state) => {
}; assert(isPlainObject(state));
if (!('_version' in state)) {
export const modelManagerV2PersistConfig: PersistConfig<ModelManagerState> = { state._version = 1;
name: modelManagerV2Slice.name, }
initialState: initialModelManagerState, return zModelManagerState.parse(state);
migrate: migrateModelManagerState, },
persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'], persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'],
},
}; };
export const selectModelManagerV2Slice = (state: RootState) => state.modelmanagerV2; export const selectModelManagerV2Slice = (state: RootState) => state.modelmanagerV2;

View File

@@ -14,7 +14,13 @@ import type {
ReactFlowProps, ReactFlowProps,
ReactFlowState, ReactFlowState,
} from '@xyflow/react'; } from '@xyflow/react';
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from '@xyflow/react'; import {
Background,
ReactFlow,
SelectionMode,
useStore as useReactFlowStore,
useUpdateNodeInternals,
} from '@xyflow/react';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus'; import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { $isSelectingOutputNode, $outputNodeId } from 'features/nodes/components/sidePanel/workflow/publish'; import { $isSelectingOutputNode, $outputNodeId } from 'features/nodes/components/sidePanel/workflow/publish';
@@ -256,7 +262,7 @@ export const Flow = memo(() => {
style={flowStyles} style={flowStyles}
onPaneClick={handlePaneClick} onPaneClick={handlePaneClick}
deleteKeyCode={null} deleteKeyCode={null}
selectionMode={selectionMode} selectionMode={selectionMode === 'full' ? SelectionMode.Full : SelectionMode.Partial}
elevateEdgesOnSelect elevateEdgesOnSelect
nodeDragThreshold={1} nodeDragThreshold={1}
noDragClassName={NO_DRAG_CLASS} noDragClassName={NO_DRAG_CLASS}

View File

@@ -11,14 +11,15 @@ import type {
XYPosition, XYPosition,
} from '@xyflow/react'; } from '@xyflow/react';
import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from '@xyflow/react'; import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from '@xyflow/react';
import type { PersistConfig } from 'app/store/store'; import type { SliceConfig } from 'app/store/types';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { isPlainObject } from 'es-toolkit';
import { import {
addElement, addElement,
removeElement, removeElement,
reparentElement, reparentElement,
} from 'features/nodes/components/sidePanel/builder/form-manipulation'; } from 'features/nodes/components/sidePanel/builder/form-manipulation';
import type { NodesState } from 'features/nodes/store/types'; import { type NodesState, zNodesState } from 'features/nodes/store/types';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type { import type {
BoardFieldValue, BoardFieldValue,
@@ -127,6 +128,7 @@ import {
import { atom, computed } from 'nanostores'; import { atom, computed } from 'nanostores';
import type { MouseEvent } from 'react'; import type { MouseEvent } from 'react';
import type { UndoableOptions } from 'redux-undo'; import type { UndoableOptions } from 'redux-undo';
import { assert } from 'tsafe';
import type { z } from 'zod'; import type { z } from 'zod';
import type { PendingConnection, Templates } from './types'; import type { PendingConnection, Templates } from './types';
@@ -151,11 +153,11 @@ export const getInitialWorkflow = (): Omit<NodesState, 'mode' | 'formFieldInitia
}; };
}; };
const initialState: NodesState = { const getInitialState = (): NodesState => ({
_version: 1, _version: 1,
formFieldInitialValues: {}, formFieldInitialValues: {},
...getInitialWorkflow(), ...getInitialWorkflow(),
}; });
type FieldValueAction<T extends FieldValue> = PayloadAction<{ type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string; nodeId: string;
@@ -208,9 +210,9 @@ const fieldValueReducer = <T extends FieldValue>(
field.value = result.data; field.value = result.data;
}; };
export const nodesSlice = createSlice({ const slice = createSlice({
name: 'nodes', name: 'nodes',
initialState: initialState, initialState: getInitialState(),
reducers: { reducers: {
nodesChanged: (state, action: PayloadAction<NodeChange<AnyNode>[]>) => { nodesChanged: (state, action: PayloadAction<NodeChange<AnyNode>[]>) => {
// In v12.7.0, @xyflow/react added a `domAttributes` property to the node data. One DOM attribute is // In v12.7.0, @xyflow/react added a `domAttributes` property to the node data. One DOM attribute is
@@ -588,7 +590,7 @@ export const nodesSlice = createSlice({
} }
node.data.notes = value; node.data.notes = value;
}, },
nodeEditorReset: () => deepClone(initialState), nodeEditorReset: () => getInitialState(),
workflowNameChanged: (state, action: PayloadAction<string>) => { workflowNameChanged: (state, action: PayloadAction<string>) => {
state.name = action.payload; state.name = action.payload;
}, },
@@ -673,7 +675,7 @@ export const nodesSlice = createSlice({
const formFieldInitialValues = getFormFieldInitialValues(workflowExtra.form, nodes); const formFieldInitialValues = getFormFieldInitialValues(workflowExtra.form, nodes);
return { return {
...deepClone(initialState), ...getInitialState(),
...deepClone(workflowExtra), ...deepClone(workflowExtra),
formFieldInitialValues, formFieldInitialValues,
nodes: nodes.map((node) => ({ ...SHARED_NODE_PROPERTIES, ...node })), nodes: nodes.map((node) => ({ ...SHARED_NODE_PROPERTIES, ...node })),
@@ -758,7 +760,7 @@ export const {
workflowLoaded, workflowLoaded,
undo, undo,
redo, redo,
} = nodesSlice.actions; } = slice.actions;
export const $cursorPos = atom<XYPosition | null>(null); export const $cursorPos = atom<XYPosition | null>(null);
export const $templates = atom<Templates>({}); export const $templates = atom<Templates>({});
@@ -775,21 +777,6 @@ export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 }); export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
export const $addNodeCmdk = atom(false); export const $addNodeCmdk = atom(false);
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateNodesState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const nodesPersistConfig: PersistConfig<NodesState> = {
name: nodesSlice.name,
initialState: initialState,
migrate: migrateNodesState,
persistDenylist: [],
};
type NodeSelectionAction = { type NodeSelectionAction = {
type: ReturnType<typeof nodesChanged>['type']; type: ReturnType<typeof nodesChanged>['type'];
payload: NodeSelectionChange[]; payload: NodeSelectionChange[];
@@ -893,10 +880,10 @@ const isHighFrequencyWorkflowDetailsAction = isAnyOf(
// a note in a notes node, we don't want to create a new undo group for every keystroke. // a note in a notes node, we don't want to create a new undo group for every keystroke.
const isHighFrequencyNodeScopedAction = isAnyOf(nodeLabelChanged, nodeNotesChanged, notesNodeValueChanged); const isHighFrequencyNodeScopedAction = isAnyOf(nodeLabelChanged, nodeNotesChanged, notesNodeValueChanged);
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = { const reduxUndoOptions: UndoableOptions<NodesState, UnknownAction> = {
limit: 64, limit: 64,
undoType: nodesSlice.actions.undo.type, undoType: slice.actions.undo.type,
redoType: nodesSlice.actions.redo.type, redoType: slice.actions.redo.type,
groupBy: (action, _state, _history) => { groupBy: (action, _state, _history) => {
if (isHighFrequencyFieldChangeAction(action)) { if (isHighFrequencyFieldChangeAction(action)) {
// Group by type, node id and field name // Group by type, node id and field name
@@ -928,7 +915,7 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
}, },
filter: (action, _state, _history) => { filter: (action, _state, _history) => {
// Ignore all actions from other slices // Ignore all actions from other slices
if (!action.type.startsWith(nodesSlice.name)) { if (!action.type.startsWith(slice.name)) {
return false; return false;
} }
// Ignore actions that only select or deselect nodes and edges // Ignore actions that only select or deselect nodes and edges
@@ -943,6 +930,24 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
}, },
}; };
export const nodesSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zNodesState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zNodesState.parse(state);
},
},
undoableConfig: {
reduxUndoOptions,
},
};
// The form builder's initial values are based on the current values of the node fields in the workflow. // The form builder's initial values are based on the current values of the node fields in the workflow.
export const getFormFieldInitialValues = (form: BuilderForm, nodes: NodesState['nodes']) => { export const getFormFieldInitialValues = (form: BuilderForm, nodes: NodesState['nodes']) => {
const formFieldInitialValues: Record<string, StatefulFieldValue> = {}; const formFieldInitialValues: Record<string, StatefulFieldValue> = {};

View File

@@ -1,7 +1,8 @@
import type { HandleType } from '@xyflow/react'; import type { HandleType } from '@xyflow/react';
import type { FieldInputTemplate, FieldOutputTemplate, StatefulFieldValue } from 'features/nodes/types/field'; import { type FieldInputTemplate, type FieldOutputTemplate, zStatefulFieldValue } from 'features/nodes/types/field';
import type { AnyEdge, AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation'; import { type InvocationTemplate, type NodeExecutionState, zAnyEdge, zAnyNode } from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { zWorkflowV3 } from 'features/nodes/types/workflow';
import z from 'zod';
export type Templates = Record<string, InvocationTemplate>; export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>; export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
@@ -13,11 +14,13 @@ export type PendingConnection = {
fieldTemplate: FieldInputTemplate | FieldOutputTemplate; fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
}; };
export type WorkflowMode = 'edit' | 'view'; export const zWorkflowMode = z.enum(['edit', 'view']);
export type WorkflowMode = z.infer<typeof zWorkflowMode>;
export type NodesState = { export const zNodesState = z.object({
_version: 1; _version: z.literal(1),
nodes: AnyNode[]; nodes: z.array(zAnyNode),
edges: AnyEdge[]; edges: z.array(zAnyEdge),
formFieldInitialValues: Record<string, StatefulFieldValue>; formFieldInitialValues: z.record(z.string(), zStatefulFieldValue),
} & Omit<WorkflowV3, 'nodes' | 'edges' | 'is_published'>; ...zWorkflowV3.omit({ nodes: true, edges: true, is_published: true }).shape,
});
export type NodesState = z.infer<typeof zNodesState>;

View File

@@ -1,34 +1,43 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { WorkflowMode } from 'features/nodes/store/types'; import type { SliceConfig } from 'app/store/types';
import { type WorkflowMode, zWorkflowMode } from 'features/nodes/store/types';
import type { WorkflowCategory } from 'features/nodes/types/workflow'; import type { WorkflowCategory } from 'features/nodes/types/workflow';
import { atom, computed } from 'nanostores'; import { atom, computed } from 'nanostores';
import type { SQLiteDirection, WorkflowRecordOrderBy } from 'services/api/types'; import {
type SQLiteDirection,
type WorkflowRecordOrderBy,
zSQLiteDirection,
zWorkflowRecordOrderBy,
} from 'services/api/types';
import z from 'zod';
export type WorkflowLibraryView = 'recent' | 'yours' | 'private' | 'shared' | 'defaults' | 'published'; const zWorkflowLibraryView = z.enum(['recent', 'yours', 'private', 'shared', 'defaults', 'published']);
export type WorkflowLibraryView = z.infer<typeof zWorkflowLibraryView>;
type WorkflowLibraryState = { const zWorkflowLibraryState = z.object({
mode: WorkflowMode; mode: zWorkflowMode,
view: WorkflowLibraryView; view: zWorkflowLibraryView,
orderBy: WorkflowRecordOrderBy; orderBy: zWorkflowRecordOrderBy,
direction: SQLiteDirection; direction: zSQLiteDirection,
searchTerm: string; searchTerm: z.string(),
selectedTags: string[]; selectedTags: z.array(z.string()),
}; });
type WorkflowLibraryState = z.infer<typeof zWorkflowLibraryState>;
const initialWorkflowLibraryState: WorkflowLibraryState = { const getInitialState = (): WorkflowLibraryState => ({
mode: 'view', mode: 'view',
searchTerm: '', searchTerm: '',
orderBy: 'opened_at', orderBy: 'opened_at',
direction: 'DESC', direction: 'DESC',
selectedTags: [], selectedTags: [],
view: 'defaults', view: 'defaults',
}; });
export const workflowLibrarySlice = createSlice({ const slice = createSlice({
name: 'workflowLibrary', name: 'workflowLibrary',
initialState: initialWorkflowLibraryState, initialState: getInitialState(),
reducers: { reducers: {
workflowModeChanged: (state, action: PayloadAction<WorkflowMode>) => { workflowModeChanged: (state, action: PayloadAction<WorkflowMode>) => {
state.mode = action.payload; state.mode = action.payload;
@@ -73,16 +82,15 @@ export const {
workflowLibraryTagToggled, workflowLibraryTagToggled,
workflowLibraryTagsReset, workflowLibraryTagsReset,
workflowLibraryViewChanged, workflowLibraryViewChanged,
} = workflowLibrarySlice.actions; } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const workflowLibrarySliceConfig: SliceConfig<typeof slice> = {
const migrateWorkflowLibraryState = (state: any): any => state; slice,
schema: zWorkflowLibraryState,
export const workflowLibraryPersistConfig: PersistConfig<WorkflowLibraryState> = { getInitialState,
name: workflowLibrarySlice.name, persistConfig: {
initialState: initialWorkflowLibraryState, migrate: (state) => zWorkflowLibraryState.parse(state),
migrate: migrateWorkflowLibraryState, },
persistDenylist: [],
}; };
const selectWorkflowLibrarySlice = (state: RootState) => state.workflowLibrary; const selectWorkflowLibrarySlice = (state: RootState) => state.workflowLibrary;

View File

@@ -1,8 +1,10 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit';
import { SelectionMode } from '@xyflow/react'; import type { RootState } from 'app/store/store';
import type { PersistConfig, RootState } from 'app/store/store'; import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { Selector } from 'react-redux'; import type { Selector } from 'react-redux';
import { assert } from 'tsafe';
import z from 'zod'; import z from 'zod';
export const zLayeringStrategy = z.enum(['network-simplex', 'longest-path']); export const zLayeringStrategy = z.enum(['network-simplex', 'longest-path']);
@@ -11,25 +13,28 @@ export const zLayoutDirection = z.enum(['TB', 'LR']);
type LayoutDirection = z.infer<typeof zLayoutDirection>; type LayoutDirection = z.infer<typeof zLayoutDirection>;
export const zNodeAlignment = z.enum(['UL', 'UR', 'DL', 'DR']); export const zNodeAlignment = z.enum(['UL', 'UR', 'DL', 'DR']);
type NodeAlignment = z.infer<typeof zNodeAlignment>; type NodeAlignment = z.infer<typeof zNodeAlignment>;
const zSelectionMode = z.enum(['partial', 'full']);
export type WorkflowSettingsState = { const zWorkflowSettingsState = z.object({
_version: 1; _version: z.literal(1),
shouldShowMinimapPanel: boolean; shouldShowMinimapPanel: z.boolean(),
layeringStrategy: LayeringStrategy; layeringStrategy: zLayeringStrategy,
nodeSpacing: number; nodeSpacing: z.number(),
layerSpacing: number; layerSpacing: z.number(),
layoutDirection: LayoutDirection; layoutDirection: zLayoutDirection,
shouldValidateGraph: boolean; shouldValidateGraph: z.boolean(),
shouldAnimateEdges: boolean; shouldAnimateEdges: z.boolean(),
nodeAlignment: NodeAlignment; nodeAlignment: zNodeAlignment,
nodeOpacity: number; nodeOpacity: z.number(),
shouldSnapToGrid: boolean; shouldSnapToGrid: z.boolean(),
shouldColorEdges: boolean; shouldColorEdges: z.boolean(),
shouldShowEdgeLabels: boolean; shouldShowEdgeLabels: z.boolean(),
selectionMode: SelectionMode; selectionMode: zSelectionMode,
}; });
const initialState: WorkflowSettingsState = { export type WorkflowSettingsState = z.infer<typeof zWorkflowSettingsState>;
const getInitialState = (): WorkflowSettingsState => ({
_version: 1, _version: 1,
shouldShowMinimapPanel: true, shouldShowMinimapPanel: true,
layeringStrategy: 'network-simplex', layeringStrategy: 'network-simplex',
@@ -43,12 +48,12 @@ const initialState: WorkflowSettingsState = {
shouldColorEdges: true, shouldColorEdges: true,
shouldShowEdgeLabels: false, shouldShowEdgeLabels: false,
nodeOpacity: 1, nodeOpacity: 1,
selectionMode: SelectionMode.Partial, selectionMode: 'partial',
}; });
export const workflowSettingsSlice = createSlice({ const slice = createSlice({
name: 'workflowSettings', name: 'workflowSettings',
initialState, initialState: getInitialState(),
reducers: { reducers: {
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => { shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowMinimapPanel = action.payload; state.shouldShowMinimapPanel = action.payload;
@@ -87,7 +92,7 @@ export const workflowSettingsSlice = createSlice({
state.nodeAlignment = action.payload; state.nodeAlignment = action.payload;
}, },
selectionModeChanged: (state, action: PayloadAction<boolean>) => { selectionModeChanged: (state, action: PayloadAction<boolean>) => {
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial; state.selectionMode = action.payload ? 'full' : 'partial';
}, },
}, },
}); });
@@ -106,21 +111,21 @@ export const {
shouldValidateGraphChanged, shouldValidateGraphChanged,
nodeOpacityChanged, nodeOpacityChanged,
selectionModeChanged, selectionModeChanged,
} = workflowSettingsSlice.actions; } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const workflowSettingsSliceConfig: SliceConfig<typeof slice> = {
const migrateWorkflowSettingsState = (state: any): any => { slice,
if (!('_version' in state)) { schema: zWorkflowSettingsState,
state._version = 1; getInitialState,
} persistConfig: {
return state; migrate: (state) => {
}; assert(isPlainObject(state));
if (!('_version' in state)) {
export const workflowSettingsPersistConfig: PersistConfig<WorkflowSettingsState> = { state._version = 1;
name: workflowSettingsSlice.name, }
initialState, return zWorkflowSettingsState.parse(state);
migrate: migrateWorkflowSettingsState, },
persistDenylist: [], },
}; };
export const selectWorkflowSettingsSlice = (state: RootState) => state.workflowSettings; export const selectWorkflowSettingsSlice = (state: RootState) => state.workflowSettings;

View File

@@ -92,7 +92,7 @@ export const zMainModelBase = z.enum([
]); ]);
type MainModelBase = z.infer<typeof zMainModelBase>; type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success; export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([ export const zModelType = z.enum([
'main', 'main',
'vae', 'vae',
'lora', 'lora',

View File

@@ -43,7 +43,7 @@ export const zNotesNodeData = z.object({
isOpen: z.boolean(), isOpen: z.boolean(),
notes: z.string(), notes: z.string(),
}); });
const _zCurrentImageNodeData = z.object({ const zCurrentImageNodeData = z.object({
id: z.string().trim().min(1), id: z.string().trim().min(1),
type: z.literal('current_image'), type: z.literal('current_image'),
label: z.string(), label: z.string(),
@@ -52,12 +52,35 @@ const _zCurrentImageNodeData = z.object({
export type NotesNodeData = z.infer<typeof zNotesNodeData>; export type NotesNodeData = z.infer<typeof zNotesNodeData>;
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>; export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
type CurrentImageNodeData = z.infer<typeof _zCurrentImageNodeData>; type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
export type InvocationNode = Node<InvocationNodeData, 'invocation'>; const zInvocationNodeValidationSchema = z.looseObject({
export type NotesNode = Node<NotesNodeData, 'notes'>; type: z.literal('invocation'),
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>; data: zInvocationNodeData,
export type AnyNode = InvocationNode | NotesNode | CurrentImageNode; });
const zInvocationNode = z.custom<Node<InvocationNodeData, 'invocation'>>(
(val) => zInvocationNodeValidationSchema.safeParse(val).success
);
export type InvocationNode = z.infer<typeof zInvocationNode>;
const zNotesNodeValidationSchema = z.looseObject({
type: z.literal('notes'),
data: zNotesNodeData,
});
const zNotesNode = z.custom<Node<NotesNodeData, 'notes'>>((val) => zNotesNodeValidationSchema.safeParse(val).success);
export type NotesNode = z.infer<typeof zNotesNode>;
const zCurrentImageNodeValidationSchema = z.looseObject({
type: z.literal('current_image'),
data: zCurrentImageNodeData,
});
const zCurrentImageNode = z.custom<Node<CurrentImageNodeData, 'current_image'>>(
(val) => zCurrentImageNodeValidationSchema.safeParse(val).success
);
export type CurrentImageNode = z.infer<typeof zCurrentImageNode>;
export const zAnyNode = z.union([zInvocationNode, zNotesNode, zCurrentImageNode]);
export type AnyNode = z.infer<typeof zAnyNode>;
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode => export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
Boolean(node && node.type === 'invocation'); Boolean(node && node.type === 'invocation');
@@ -83,13 +106,29 @@ export type NodeExecutionState = z.infer<typeof _zNodeExecutionState>;
// #endregion // #endregion
// #region Edges // #region Edges
const _zInvocationNodeEdgeCollapsedData = z.object({ const zDefaultInvocationNodeEdgeValidationSchema = z.looseObject({
type: z.literal('default'),
});
const zDefaultInvocationNodeEdge = z.custom<Edge<Record<string, never>, 'default'>>(
(val) => zDefaultInvocationNodeEdgeValidationSchema.safeParse(val).success
);
export type DefaultInvocationNodeEdge = z.infer<typeof zDefaultInvocationNodeEdge>;
const zInvocationNodeEdgeCollapsedData = z.object({
count: z.number().int().min(1), count: z.number().int().min(1),
}); });
type InvocationNodeEdgeCollapsedData = z.infer<typeof _zInvocationNodeEdgeCollapsedData>; const zInvocationNodeEdgeCollapsedValidationSchema = z.looseObject({
export type DefaultInvocationNodeEdge = Edge<Record<string, never>, 'default'>; type: z.literal('default'),
export type CollapsedInvocationNodeEdge = Edge<InvocationNodeEdgeCollapsedData, 'collapsed'>; data: zInvocationNodeEdgeCollapsedData,
export type AnyEdge = DefaultInvocationNodeEdge | CollapsedInvocationNodeEdge; });
type InvocationNodeEdgeCollapsedData = z.infer<typeof zInvocationNodeEdgeCollapsedData>;
const zCollapsedInvocationNodeEdge = z.custom<Edge<InvocationNodeEdgeCollapsedData, 'collapsed'>>(
(val) => zInvocationNodeEdgeCollapsedValidationSchema.safeParse(val).success
);
export type CollapsedInvocationNodeEdge = z.infer<typeof zCollapsedInvocationNodeEdge>;
export const zAnyEdge = z.union([zDefaultInvocationNodeEdge, zCollapsedInvocationNodeEdge]);
export type AnyEdge = z.infer<typeof zAnyEdge>;
// #endregion // #endregion
export const isBatchNodeType = (type: string) => export const isBatchNodeType = (type: string) =>

View File

@@ -1,4 +1,5 @@
import { FormControl, FormLabel } from '@invoke-ai/ui-library'; import { FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { selectBase } from 'features/controlLayers/store/paramsSlice'; import { selectBase } from 'features/controlLayers/store/paramsSlice';
@@ -6,13 +7,35 @@ import { ModelPicker } from 'features/parameters/components/ModelPicker';
import { selectTileControlNetModel, tileControlnetModelChanged } from 'features/parameters/store/upscaleSlice'; import { selectTileControlNetModel, tileControlnetModelChanged } from 'features/parameters/store/upscaleSlice';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { useControlNetModels } from 'services/api/hooks/modelsByType'; import { useControlNetModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig } from 'services/api/types'; import { type ControlNetModelConfig, isControlNetModelConfig } from 'services/api/types';
const selectTileControlNetModelConfig = createSelector(
selectModelConfigsQuery,
selectTileControlNetModel,
(modelConfigs, modelIdentifierField) => {
if (!modelConfigs.data) {
return null;
}
if (!modelIdentifierField) {
return null;
}
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, modelIdentifierField.key);
if (!modelConfig) {
return null;
}
if (!isControlNetModelConfig(modelConfig)) {
return null;
}
return modelConfig;
}
);
const ParamTileControlNetModel = () => { const ParamTileControlNetModel = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const tileControlNetModel = useAppSelector(selectTileControlNetModel); const tileControlNetModel = useAppSelector(selectTileControlNetModelConfig);
const currentBaseModel = useAppSelector(selectBase); const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useControlNetModels(); const [modelConfigs, { isLoading }] = useControlNetModels();

View File

@@ -1,21 +1,21 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice'; import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice'; import { selectConfigSlice } from 'features/system/store/configSlice';
import { useMemo } from 'react'; import { useMemo } from 'react';
import type { ImageDTO } from 'services/api/types';
const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) => const createIsTooLargeToUpscaleSelector = (imageWithDims?: ImageWithDims | null) =>
createSelector(selectUpscaleSlice, selectConfigSlice, (upscale, config) => { createSelector(selectUpscaleSlice, selectConfigSlice, (upscale, config) => {
const { upscaleModel, scale } = upscale; const { upscaleModel, scale } = upscale;
const { maxUpscaleDimension } = config; const { maxUpscaleDimension } = config;
if (!maxUpscaleDimension || !upscaleModel || !imageDTO) { if (!maxUpscaleDimension || !upscaleModel || !imageWithDims) {
// When these are missing, another warning will be shown // When these are missing, another warning will be shown
return false; return false;
} }
const { width, height } = imageDTO; const { width, height } = imageWithDims;
const maxPixels = maxUpscaleDimension ** 2; const maxPixels = maxUpscaleDimension ** 2;
const upscaledPixels = width * scale * height * scale; const upscaledPixels = width * scale * height * scale;
@@ -23,7 +23,7 @@ const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) =>
return upscaledPixels > maxPixels; return upscaledPixels > maxPixels;
}); });
export const useIsTooLargeToUpscale = (imageDTO?: ImageDTO | null) => { export const useIsTooLargeToUpscale = (imageWithDims?: ImageWithDims | null) => {
const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageDTO), [imageDTO]); const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageWithDims), [imageWithDims]);
return useAppSelector(selectIsTooLargeToUpscale); return useAppSelector(selectIsTooLargeToUpscale);
}; };

View File

@@ -1,24 +1,33 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterSpandrelImageToImageModel } from 'features/parameters/types/parameterSchemas'; import type { ParameterSpandrelImageToImageModel } from 'features/parameters/types/parameterSchemas';
import type { ControlNetModelConfig, ImageDTO } from 'services/api/types'; import type { ControlNetModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import z from 'zod';
export interface UpscaleState { const zUpscaleState = z.object({
_version: 1; _version: z.literal(2),
upscaleModel: ParameterSpandrelImageToImageModel | null; upscaleModel: zModelIdentifierField.nullable(),
upscaleInitialImage: ImageDTO | null; upscaleInitialImage: zImageWithDims.nullable(),
structure: number; structure: z.number(),
creativity: number; creativity: z.number(),
tileControlnetModel: ControlNetModelConfig | null; tileControlnetModel: zModelIdentifierField.nullable(),
scale: number; scale: z.number(),
postProcessingModel: ParameterSpandrelImageToImageModel | null; postProcessingModel: zModelIdentifierField.nullable(),
tileSize: number; tileSize: z.number(),
tileOverlap: number; tileOverlap: z.number(),
} });
const initialUpscaleState: UpscaleState = { export type UpscaleState = z.infer<typeof zUpscaleState>;
_version: 1,
const getInitialState = (): UpscaleState => ({
_version: 2,
upscaleModel: null, upscaleModel: null,
upscaleInitialImage: null, upscaleInitialImage: null,
structure: 0, structure: 0,
@@ -28,16 +37,19 @@ const initialUpscaleState: UpscaleState = {
postProcessingModel: null, postProcessingModel: null,
tileSize: 1024, tileSize: 1024,
tileOverlap: 128, tileOverlap: 128,
}; });
export const upscaleSlice = createSlice({ const slice = createSlice({
name: 'upscale', name: 'upscale',
initialState: initialUpscaleState, initialState: getInitialState(),
reducers: { reducers: {
upscaleModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => { upscaleModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
state.upscaleModel = action.payload; const result = zUpscaleState.shape.upscaleModel.safeParse(action.payload);
if (result.success) {
state.upscaleModel = result.data;
}
}, },
upscaleInitialImageChanged: (state, action: PayloadAction<ImageDTO | null>) => { upscaleInitialImageChanged: (state, action: PayloadAction<ImageWithDims | null>) => {
state.upscaleInitialImage = action.payload; state.upscaleInitialImage = action.payload;
}, },
structureChanged: (state, action: PayloadAction<number>) => { structureChanged: (state, action: PayloadAction<number>) => {
@@ -47,13 +59,19 @@ export const upscaleSlice = createSlice({
state.creativity = action.payload; state.creativity = action.payload;
}, },
tileControlnetModelChanged: (state, action: PayloadAction<ControlNetModelConfig | null>) => { tileControlnetModelChanged: (state, action: PayloadAction<ControlNetModelConfig | null>) => {
state.tileControlnetModel = action.payload; const result = zUpscaleState.shape.tileControlnetModel.safeParse(action.payload);
if (result.success) {
state.tileControlnetModel = result.data;
}
}, },
scaleChanged: (state, action: PayloadAction<number>) => { scaleChanged: (state, action: PayloadAction<number>) => {
state.scale = action.payload; state.scale = action.payload;
}, },
postProcessingModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => { postProcessingModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
state.postProcessingModel = action.payload; const result = zUpscaleState.shape.postProcessingModel.safeParse(action.payload);
if (result.success) {
state.postProcessingModel = result.data;
}
}, },
tileSizeChanged: (state, action: PayloadAction<number>) => { tileSizeChanged: (state, action: PayloadAction<number>) => {
state.tileSize = action.payload; state.tileSize = action.payload;
@@ -74,21 +92,33 @@ export const {
postProcessingModelChanged, postProcessingModelChanged,
tileSizeChanged, tileSizeChanged,
tileOverlapChanged, tileOverlapChanged,
} = upscaleSlice.actions; } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const upscaleSliceConfig: SliceConfig<typeof slice> = {
const migrateUpscaleState = (state: any): any => { slice,
if (!('_version' in state)) { schema: zUpscaleState,
state._version = 1; getInitialState,
} persistConfig: {
return state; migrate: (state) => {
}; assert(isPlainObject(state));
if (!('_version' in state)) {
export const upscalePersistConfig: PersistConfig<UpscaleState> = { state._version = 1;
name: upscaleSlice.name, }
initialState: initialUpscaleState, if (state._version === 1) {
migrate: migrateUpscaleState, state._version = 2;
persistDenylist: [], // Migrate from v1 to v2: upscaleInitialImage was an ImageDTO, now it's an ImageWithDims
if (state.upscaleInitialImage) {
const { image_name, width, height } = state.upscaleInitialImage;
state.upscaleInitialImage = {
image_name,
width,
height,
};
}
}
return zUpscaleState.parse(state);
},
},
}; };
export const selectUpscaleSlice = (state: RootState) => state.upscale; export const selectUpscaleSlice = (state: RootState) => state.upscale;

View File

@@ -1,24 +1,27 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import z from 'zod';
interface QueueState { const zQueueState = z.object({
listCursor: number | undefined; listCursor: z.number().optional(),
listPriority: number | undefined; listPriority: z.number().optional(),
selectedQueueItem: string | undefined; selectedQueueItem: z.string().optional(),
resumeProcessorOnEnqueue: boolean; resumeProcessorOnEnqueue: z.boolean(),
} });
type QueueState = z.infer<typeof zQueueState>;
const initialQueueState: QueueState = { const getInitialState = (): QueueState => ({
listCursor: undefined, listCursor: undefined,
listPriority: undefined, listPriority: undefined,
selectedQueueItem: undefined, selectedQueueItem: undefined,
resumeProcessorOnEnqueue: true, resumeProcessorOnEnqueue: true,
}; });
export const queueSlice = createSlice({ const slice = createSlice({
name: 'queue', name: 'queue',
initialState: initialQueueState, initialState: getInitialState(),
reducers: { reducers: {
listCursorChanged: (state, action: PayloadAction<number | undefined>) => { listCursorChanged: (state, action: PayloadAction<number | undefined>) => {
state.listCursor = action.payload; state.listCursor = action.payload;
@@ -33,7 +36,13 @@ export const queueSlice = createSlice({
}, },
}); });
export const { listCursorChanged, listPriorityChanged, listParamsReset } = queueSlice.actions; export const { listCursorChanged, listPriorityChanged, listParamsReset } = slice.actions;
export const queueSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zQueueState,
getInitialState,
};
const selectQueueSlice = (state: RootState) => state.queue; const selectQueueSlice = (state: RootState) => state.queue;
const createQueueSelector = <T>(selector: Selector<QueueState, T>) => createSelector(selectQueueSlice, selector); const createQueueSelector = <T>(selector: Selector<QueueState, T>) => createSelector(selectQueueSlice, selector);

View File

@@ -1,6 +1,7 @@
import { Flex, Text } from '@invoke-ai/ui-library'; import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton'; import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import type { SetUpscaleInitialImageDndTargetData } from 'features/dnd/dnd'; import type { SetUpscaleInitialImageDndTargetData } from 'features/dnd/dnd';
import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd'; import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget'; import { DndDropTarget } from 'features/dnd/DndDropTarget';
@@ -10,11 +11,13 @@ import { selectUpscaleInitialImage, upscaleInitialImageChanged } from 'features/
import { t } from 'i18next'; import { t } from 'i18next';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi'; import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { useImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types'; import type { ImageDTO } from 'services/api/types';
export const UpscaleInitialImage = () => { export const UpscaleInitialImage = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const imageDTO = useAppSelector(selectUpscaleInitialImage); const upscaleInitialImage = useAppSelector(selectUpscaleInitialImage);
const imageDTO = useImageDTO(upscaleInitialImage?.image_name);
const dndTargetData = useMemo<SetUpscaleInitialImageDndTargetData>( const dndTargetData = useMemo<SetUpscaleInitialImageDndTargetData>(
() => setUpscaleInitialImageDndTarget.getData(), () => setUpscaleInitialImageDndTarget.getData(),
[] []
@@ -26,7 +29,7 @@ export const UpscaleInitialImage = () => {
const onUpload = useCallback( const onUpload = useCallback(
(imageDTO: ImageDTO) => { (imageDTO: ImageDTO) => {
dispatch(upscaleInitialImageChanged(imageDTO)); dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
}, },
[dispatch] [dispatch]
); );

View File

@@ -31,8 +31,10 @@ export const UpscaleWarning = () => {
const validModel = modelConfigs.find((cnetModel) => { const validModel = modelConfigs.find((cnetModel) => {
return cnetModel.base === model?.base && cnetModel.name.toLowerCase().includes('tile'); return cnetModel.base === model?.base && cnetModel.name.toLowerCase().includes('tile');
}); });
dispatch(tileControlnetModelChanged(validModel || null)); if (tileControlnetModel?.key !== validModel?.key) {
}, [model?.base, modelConfigs, dispatch]); dispatch(tileControlnetModelChanged(validModel || null));
}
}, [dispatch, model?.base, modelConfigs, tileControlnetModel?.key]);
const isBaseModelCompatible = useMemo(() => { const isBaseModelCompatible = useMemo(() => {
return model && ['sd-1', 'sdxl'].includes(model.base); return model && ['sd-1', 'sdxl'].includes(model.base);

View File

@@ -1,23 +1,33 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone'; import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { paramsReset } from 'features/controlLayers/store/paramsSlice'; import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets'; import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { assert } from 'tsafe';
import z from 'zod';
import type { StylePresetState } from './types'; const zStylePresetState = z.object({
activeStylePresetId: z.string().nullable(),
searchTerm: z.string(),
viewMode: z.boolean(),
showPromptPreviews: z.boolean(),
});
const initialState: StylePresetState = { type StylePresetState = z.infer<typeof zStylePresetState>;
const getInitialState = (): StylePresetState => ({
activeStylePresetId: null, activeStylePresetId: null,
searchTerm: '', searchTerm: '',
viewMode: false, viewMode: false,
showPromptPreviews: false, showPromptPreviews: false,
}; });
export const stylePresetSlice = createSlice({ const slice = createSlice({
name: 'stylePreset', name: 'stylePreset',
initialState: initialState, initialState: getInitialState(),
reducers: { reducers: {
activeStylePresetIdChanged: (state, action: PayloadAction<string | null>) => { activeStylePresetIdChanged: (state, action: PayloadAction<string | null>) => {
state.activeStylePresetId = action.payload; state.activeStylePresetId = action.payload;
@@ -34,7 +44,7 @@ export const stylePresetSlice = createSlice({
}, },
extraReducers(builder) { extraReducers(builder) {
builder.addCase(paramsReset, () => { builder.addCase(paramsReset, () => {
return deepClone(initialState); return getInitialState();
}); });
builder.addMatcher(stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled, (state, action) => { builder.addMatcher(stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled, (state, action) => {
if (state.activeStylePresetId === null) { if (state.activeStylePresetId === null) {
@@ -58,21 +68,21 @@ export const stylePresetSlice = createSlice({
}); });
export const { activeStylePresetIdChanged, searchTermChanged, viewModeChanged, showPromptPreviewsChanged } = export const { activeStylePresetIdChanged, searchTermChanged, viewModeChanged, showPromptPreviewsChanged } =
stylePresetSlice.actions; slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const stylePresetSliceConfig: SliceConfig<typeof slice> = {
const migrateStylePresetState = (state: any): any => { slice,
if (!('_version' in state)) { schema: zStylePresetState,
state._version = 1; getInitialState,
} persistConfig: {
return state; migrate: (state) => {
}; assert(isPlainObject(state));
if (!('_version' in state)) {
export const stylePresetPersistConfig: PersistConfig<StylePresetState> = { state._version = 1;
name: stylePresetSlice.name, }
initialState, return zStylePresetState.parse(state);
migrate: migrateStylePresetState, },
persistDenylist: [], },
}; };
export const selectStylePresetSlice = (state: RootState) => state.stylePreset; export const selectStylePresetSlice = (state: RootState) => state.stylePreset;

View File

@@ -1,6 +0,0 @@
export type StylePresetState = {
activeStylePresetId: string | null;
searchTerm: string;
viewMode: boolean;
showPromptPreviews: boolean;
};

View File

@@ -14,11 +14,11 @@ import {
Switch, Switch,
Text, Text,
} from '@invoke-ai/ui-library'; } from '@invoke-ai/ui-library';
import { useClearStorage } from 'app/contexts/clear-storage-context';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { buildUseBoolean } from 'common/hooks/useBoolean'; import { buildUseBoolean } from 'common/hooks/useBoolean';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { selectShouldUseCPUNoise, shouldUseCpuNoiseChanged } from 'features/controlLayers/store/paramsSlice'; import { selectShouldUseCPUNoise, shouldUseCpuNoiseChanged } from 'features/controlLayers/store/paramsSlice';
import { useRefreshAfterResetModal } from 'features/system/components/SettingsModal/RefreshAfterResetModal'; import { useRefreshAfterResetModal } from 'features/system/components/SettingsModal/RefreshAfterResetModal';
import { SettingsDeveloperLogIsEnabled } from 'features/system/components/SettingsModal/SettingsDeveloperLogIsEnabled'; import { SettingsDeveloperLogIsEnabled } from 'features/system/components/SettingsModal/SettingsDeveloperLogIsEnabled';

View File

@@ -1,193 +1,25 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { AppConfig, NumericalParameterConfig, PartialAppConfig } from 'app/types/invokeai'; import type { SliceConfig } from 'app/store/types';
import { getDefaultAppConfig, type PartialAppConfig, zAppConfig } from 'app/types/invokeai';
import { merge } from 'es-toolkit/compat'; import { merge } from 'es-toolkit/compat';
import z from 'zod';
const baseDimensionConfig: NumericalParameterConfig = { const zConfigState = z.object({
initial: 512, // determined by model selection, unused in practice ...zAppConfig.shape,
sliderMin: 64, didLoad: z.boolean(),
sliderMax: 1536, });
numberInputMin: 64, type ConfigState = z.infer<typeof zConfigState>;
numberInputMax: 4096,
fineStep: 8,
coarseStep: 64,
};
const initialConfigState: AppConfig & { didLoad: boolean } = { const getInitialState = (): ConfigState => ({
...getDefaultAppConfig(),
didLoad: false, didLoad: false,
isLocal: true, });
shouldUpdateImagesOnConnect: false,
shouldFetchMetadataFromApi: false,
allowPrivateBoards: false,
allowPrivateStylePresets: false,
allowClientSideUpload: false,
allowPublishWorkflows: false,
allowPromptExpansion: false,
shouldShowCredits: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'],
nodesAllowlist: undefined,
nodesDenylist: undefined,
sd: {
disabledControlNetModels: [],
disabledControlNetProcessors: [],
iterations: {
initial: 1,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 1,
},
width: { ...baseDimensionConfig },
height: { ...baseDimensionConfig },
boundingBoxWidth: { ...baseDimensionConfig },
boundingBoxHeight: { ...baseDimensionConfig },
scaledBoundingBoxWidth: { ...baseDimensionConfig },
scaledBoundingBoxHeight: { ...baseDimensionConfig },
scheduler: 'dpmpp_3m_k',
vaePrecision: 'fp32',
steps: {
initial: 30,
sliderMin: 1,
sliderMax: 100,
numberInputMin: 1,
numberInputMax: 500,
fineStep: 1,
coarseStep: 1,
},
guidance: {
initial: 7,
sliderMin: 1,
sliderMax: 20,
numberInputMin: 1,
numberInputMax: 200,
fineStep: 0.1,
coarseStep: 0.5,
},
img2imgStrength: {
initial: 0.7,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceStrength: {
initial: 0.3,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
hrfStrength: {
initial: 0.45,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceEdgeSize: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 1024,
fineStep: 8,
coarseStep: 16,
},
cfgRescaleMultiplier: {
initial: 0,
sliderMin: 0,
sliderMax: 0.99,
numberInputMin: 0,
numberInputMax: 0.99,
fineStep: 0.05,
coarseStep: 0.1,
},
clipSkip: {
initial: 0,
sliderMin: 0,
sliderMax: 12, // determined by model selection, unused in practice
numberInputMin: 0,
numberInputMax: 12, // determined by model selection, unused in practice
fineStep: 1,
coarseStep: 1,
},
infillPatchmatchDownscaleSize: {
initial: 1,
sliderMin: 1,
sliderMax: 10,
numberInputMin: 1,
numberInputMax: 10,
fineStep: 1,
coarseStep: 1,
},
infillTileSize: {
initial: 32,
sliderMin: 16,
sliderMax: 64,
numberInputMin: 16,
numberInputMax: 256,
fineStep: 1,
coarseStep: 1,
},
maskBlur: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 512,
fineStep: 1,
coarseStep: 1,
},
ca: {
weight: {
initial: 1,
sliderMin: 0,
sliderMax: 2,
numberInputMin: -1,
numberInputMax: 2,
fineStep: 0.01,
coarseStep: 0.05,
},
},
dynamicPrompts: {
maxPrompts: {
initial: 100,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 10,
},
},
},
flux: {
guidance: {
initial: 4,
sliderMin: 2,
sliderMax: 6,
numberInputMin: 1,
numberInputMax: 20,
fineStep: 0.1,
coarseStep: 0.5,
},
},
};
export const configSlice = createSlice({ const slice = createSlice({
name: 'config', name: 'config',
initialState: initialConfigState, initialState: getInitialState(),
reducers: { reducers: {
configChanged: (state, action: PayloadAction<PartialAppConfig>) => { configChanged: (state, action: PayloadAction<PartialAppConfig>) => {
merge(state, action.payload); merge(state, action.payload);
@@ -196,11 +28,16 @@ export const configSlice = createSlice({
}, },
}); });
export const { configChanged } = configSlice.actions; export const { configChanged } = slice.actions;
export const configSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zConfigState,
getInitialState,
};
export const selectConfigSlice = (state: RootState) => state.config; export const selectConfigSlice = (state: RootState) => state.config;
const createConfigSelector = <T>(selector: Selector<typeof initialConfigState, T>) => const createConfigSelector = <T>(selector: Selector<ConfigState, T>) => createSelector(selectConfigSlice, selector);
createSelector(selectConfigSlice, selector);
export const selectWidthConfig = createConfigSelector((config) => config.sd.width); export const selectWidthConfig = createConfigSelector((config) => config.sd.width);
export const selectHeightConfig = createConfigSelector((config) => config.sd.height); export const selectHeightConfig = createConfigSelector((config) => config.sd.height);

View File

@@ -3,12 +3,15 @@ import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { LogNamespace } from 'app/logging/logger'; import type { LogNamespace } from 'app/logging/logger';
import { zLogNamespace } from 'app/logging/logger'; import { zLogNamespace } from 'app/logging/logger';
import { EMPTY_ARRAY } from 'app/store/constants'; import { EMPTY_ARRAY } from 'app/store/constants';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { uniq } from 'es-toolkit/compat'; import { uniq } from 'es-toolkit/compat';
import { assert } from 'tsafe';
import type { Language, SystemState } from './types'; import { type Language, type SystemState, zSystemState } from './types';
const initialSystemState: SystemState = { const getInitialState = (): SystemState => ({
_version: 2, _version: 2,
shouldConfirmOnDelete: true, shouldConfirmOnDelete: true,
shouldAntialiasProgressImage: false, shouldAntialiasProgressImage: false,
@@ -23,11 +26,11 @@ const initialSystemState: SystemState = {
logNamespaces: [...zLogNamespace.options], logNamespaces: [...zLogNamespace.options],
shouldShowInvocationProgressDetail: false, shouldShowInvocationProgressDetail: false,
shouldHighlightFocusedRegions: false, shouldHighlightFocusedRegions: false,
}; });
export const systemSlice = createSlice({ const slice = createSlice({
name: 'system', name: 'system',
initialState: initialSystemState, initialState: getInitialState(),
reducers: { reducers: {
setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => { setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => {
state.shouldConfirmOnDelete = action.payload; state.shouldConfirmOnDelete = action.payload;
@@ -89,25 +92,25 @@ export const {
shouldConfirmOnNewSessionToggled, shouldConfirmOnNewSessionToggled,
setShouldShowInvocationProgressDetail, setShouldShowInvocationProgressDetail,
setShouldHighlightFocusedRegions, setShouldHighlightFocusedRegions,
} = systemSlice.actions; } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const systemSliceConfig: SliceConfig<typeof slice> = {
const migrateSystemState = (state: any): any => { slice,
if (!('_version' in state)) { schema: zSystemState,
state._version = 1; getInitialState,
} persistConfig: {
if (state._version === 1) { migrate: (state) => {
state.language = (state as SystemState).language.replace('_', '-'); assert(isPlainObject(state));
state._version = 2; if (!('_version' in state)) {
} state._version = 1;
return state; }
}; if (state._version === 1) {
state.language = (state as SystemState).language.replace('_', '-');
export const systemPersistConfig: PersistConfig<SystemState> = { state._version = 2;
name: systemSlice.name, }
initialState: initialSystemState, return zSystemState.parse(state);
migrate: migrateSystemState, },
persistDenylist: [], },
}; };
export const selectSystemSlice = (state: RootState) => state.system; export const selectSystemSlice = (state: RootState) => state.system;

View File

@@ -1,4 +1,4 @@
import type { LogLevel, LogNamespace } from 'app/logging/logger'; import { zLogLevel, zLogNamespace } from 'app/logging/logger';
import { z } from 'zod'; import { z } from 'zod';
const zLanguage = z.enum([ const zLanguage = z.enum([
@@ -29,19 +29,20 @@ const zLanguage = z.enum([
export type Language = z.infer<typeof zLanguage>; export type Language = z.infer<typeof zLanguage>;
export const isLanguage = (v: unknown): v is Language => zLanguage.safeParse(v).success; export const isLanguage = (v: unknown): v is Language => zLanguage.safeParse(v).success;
export interface SystemState { export const zSystemState = z.object({
_version: 2; _version: z.literal(2),
shouldConfirmOnDelete: boolean; shouldConfirmOnDelete: z.boolean(),
shouldAntialiasProgressImage: boolean; shouldAntialiasProgressImage: z.boolean(),
shouldConfirmOnNewSession: boolean; shouldConfirmOnNewSession: z.boolean(),
language: Language; language: zLanguage,
shouldUseNSFWChecker: boolean; shouldUseNSFWChecker: z.boolean(),
shouldUseWatermarker: boolean; shouldUseWatermarker: z.boolean(),
shouldEnableInformationalPopovers: boolean; shouldEnableInformationalPopovers: z.boolean(),
shouldEnableModelDescriptions: boolean; shouldEnableModelDescriptions: z.boolean(),
logIsEnabled: boolean; logIsEnabled: z.boolean(),
logLevel: LogLevel; logLevel: zLogLevel,
logNamespaces: LogNamespace[]; logNamespaces: z.array(zLogNamespace),
shouldShowInvocationProgressDetail: boolean; shouldShowInvocationProgressDetail: z.boolean(),
shouldHighlightFocusedRegions: boolean; shouldHighlightFocusedRegions: z.boolean(),
} });
export type SystemState = z.infer<typeof zSystemState>;

View File

@@ -1,6 +1,7 @@
import { Box, Button, ButtonGroup, Flex, Grid, Heading, Icon, Text } from '@invoke-ai/ui-library'; import { Box, Button, ButtonGroup, Flex, Grid, Heading, Icon, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd'; import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget'; import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { import {
@@ -37,7 +38,7 @@ export const UpscalingLaunchpadPanel = memo(() => {
const onUpload = useCallback( const onUpload = useCallback(
(imageDTO: ImageDTO) => { (imageDTO: ImageDTO) => {
dispatch(upscaleInitialImageChanged(imageDTO)); dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
}, },
[dispatch] [dispatch]
); );

View File

@@ -1,11 +1,13 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { assert } from 'tsafe';
import type { UIState } from './uiTypes'; import { getInitialUIState, type UIState, zUIState } from './uiTypes';
import { getInitialUIState } from './uiTypes';
export const uiSlice = createSlice({ const slice = createSlice({
name: 'ui', name: 'ui',
initialState: getInitialUIState(), initialState: getInitialUIState(),
reducers: { reducers: {
@@ -81,29 +83,30 @@ export const {
textAreaSizesStateChanged, textAreaSizesStateChanged,
dockviewStorageKeyChanged, dockviewStorageKeyChanged,
pickerCompactViewStateChanged, pickerCompactViewStateChanged,
} = uiSlice.actions; } = slice.actions;
export const selectUiSlice = (state: RootState) => state.ui; export const selectUiSlice = (state: RootState) => state.ui;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ export const uiSliceConfig: SliceConfig<typeof slice> = {
const migrateUIState = (state: any): any => { slice,
if (!('_version' in state)) { schema: zUIState,
state._version = 1; getInitialState: getInitialUIState,
} persistConfig: {
if (state._version === 1) { migrate: (state) => {
state.activeTab = 'generation'; assert(isPlainObject(state));
state._version = 2; if (!('_version' in state)) {
} state._version = 1;
if (state._version === 2) { }
state.activeTab = 'canvas'; if (state._version === 1) {
state._version = 3; state.activeTab = 'generation';
} state._version = 2;
return state; }
}; if (state._version === 2) {
state.activeTab = 'canvas';
export const uiPersistConfig: PersistConfig<UIState> = { state._version = 3;
name: uiSlice.name, }
initialState: getInitialUIState(), return zUIState.parse(state);
migrate: migrateUIState, },
persistDenylist: ['shouldShowImageDetails'], persistDenylist: ['shouldShowImageDetails'],
},
}; };

View File

@@ -1,8 +1,7 @@
import { deepClone } from 'common/util/deepClone';
import { isPlainObject } from 'es-toolkit'; import { isPlainObject } from 'es-toolkit';
import { z } from 'zod'; import { z } from 'zod';
const zTabName = z.enum(['generate', 'canvas', 'upscaling', 'workflows', 'models', 'queue']); export const zTabName = z.enum(['generate', 'canvas', 'upscaling', 'workflows', 'models', 'queue']);
export type TabName = z.infer<typeof zTabName>; export type TabName = z.infer<typeof zTabName>;
const zPartialDimensions = z.object({ const zPartialDimensions = z.object({
@@ -13,18 +12,28 @@ const zPartialDimensions = z.object({
const zSerializable = z.any().refine(isPlainObject); const zSerializable = z.any().refine(isPlainObject);
export type Serializable = z.infer<typeof zSerializable>; export type Serializable = z.infer<typeof zSerializable>;
const zUIState = z.object({ export const zUIState = z.object({
_version: z.literal(3).default(3), _version: z.literal(3),
activeTab: zTabName.default('generate'), activeTab: zTabName,
shouldShowImageDetails: z.boolean().default(false), shouldShowImageDetails: z.boolean(),
shouldShowProgressInViewer: z.boolean().default(true), shouldShowProgressInViewer: z.boolean(),
accordions: z.record(z.string(), z.boolean()).default(() => ({})), accordions: z.record(z.string(), z.boolean()),
expanders: z.record(z.string(), z.boolean()).default(() => ({})), expanders: z.record(z.string(), z.boolean()),
textAreaSizes: z.record(z.string(), zPartialDimensions).default({}), textAreaSizes: z.record(z.string(), zPartialDimensions),
panels: z.record(z.string(), zSerializable).default({}), panels: z.record(z.string(), zSerializable),
shouldShowNotificationV2: z.boolean().default(true), shouldShowNotificationV2: z.boolean(),
pickerCompactViewStates: z.record(z.string(), z.boolean()).default(() => ({})), pickerCompactViewStates: z.record(z.string(), z.boolean()),
}); });
const INITIAL_STATE = zUIState.parse({});
export type UIState = z.infer<typeof zUIState>; export type UIState = z.infer<typeof zUIState>;
export const getInitialUIState = (): UIState => deepClone(INITIAL_STATE); export const getInitialUIState = (): UIState => ({
_version: 3 as const,
activeTab: 'generate' as const,
shouldShowImageDetails: false,
shouldShowProgressInViewer: true,
accordions: {},
expanders: {},
textAreaSizes: {},
panels: {},
shouldShowNotificationV2: true,
pickerCompactViewStates: {},
});

View File

@@ -1,5 +1,6 @@
import { $openAPISchemaUrl } from 'app/store/nanostores/openAPISchemaUrl'; import { $openAPISchemaUrl } from 'app/store/nanostores/openAPISchemaUrl';
import type { OpenAPIV3_1 } from 'openapi-types'; import type { OpenAPIV3_1 } from 'openapi-types';
import type { stringify } from 'querystring';
import type { paths } from 'services/api/schema'; import type { paths } from 'services/api/schema';
import type { AppConfig, AppVersion } from 'services/api/types'; import type { AppConfig, AppVersion } from 'services/api/types';
@@ -11,7 +12,8 @@ import { api, buildV1Url } from '..';
* buildAppInfoUrl('some-path') * buildAppInfoUrl('some-path')
* // '/api/v1/app/some-path' * // '/api/v1/app/some-path'
*/ */
const buildAppInfoUrl = (path: string = '') => buildV1Url(`app/${path}`); export const buildAppInfoUrl = (path: string = '', query?: Parameters<typeof stringify>[0]) =>
buildV1Url(`app/${path}`, query);
export const appInfoApi = api.injectEndpoints({ export const appInfoApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
@@ -87,6 +89,31 @@ export const appInfoApi = api.injectEndpoints({
}, },
providesTags: ['Schema'], providesTags: ['Schema'],
}), }),
getClientStateByKey: build.query<
paths['/api/v1/app/client_state']['get']['responses']['200']['content']['application/json'],
paths['/api/v1/app/client_state']['get']['parameters']['query']
>({
query: () => ({
url: buildAppInfoUrl('client_state'),
method: 'GET',
}),
}),
setClientStateByKey: build.mutation<
paths['/api/v1/app/client_state']['post']['responses']['200']['content']['application/json'],
paths['/api/v1/app/client_state']['post']['requestBody']['content']['application/json']
>({
query: (body) => ({
url: buildAppInfoUrl('client_state'),
method: 'POST',
body,
}),
}),
deleteClientState: build.mutation<void, void>({
query: () => ({
url: buildAppInfoUrl('client_state'),
method: 'DELETE',
}),
}),
}), }),
}); });

View File

@@ -57,13 +57,18 @@ const tagTypes = [
// This is invalidated on reconnect. It should be used for queries that have changing data, // This is invalidated on reconnect. It should be used for queries that have changing data,
// especially related to the queue and generation. // especially related to the queue and generation.
'FetchOnReconnect', 'FetchOnReconnect',
'ClientState',
] as const; ] as const;
export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>; export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
export const LIST_TAG = 'LIST'; export const LIST_TAG = 'LIST';
export const LIST_ALL_TAG = 'LIST_ALL'; export const LIST_ALL_TAG = 'LIST_ALL';
const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryError> = (args, api, extraOptions) => { export const getBaseUrl = (): string => {
const baseUrl = $baseUrl.get(); const baseUrl = $baseUrl.get();
return baseUrl || window.location.href.replace(/\/$/, '');
};
const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryError> = (args, api, extraOptions) => {
const authToken = $authToken.get(); const authToken = $authToken.get();
const projectId = $projectId.get(); const projectId = $projectId.get();
const isOpenAPIRequest = const isOpenAPIRequest =
@@ -71,7 +76,7 @@ const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryE
(typeof args === 'string' && args.includes('openapi.json')); (typeof args === 'string' && args.includes('openapi.json'));
const fetchBaseQueryArgs: FetchBaseQueryArgs = { const fetchBaseQueryArgs: FetchBaseQueryArgs = {
baseUrl: baseUrl || window.location.href.replace(/\/$/, ''), baseUrl: getBaseUrl(),
}; };
// When fetching the openapi.json, we need to remove circular references from the JSON. // When fetching the openapi.json, we need to remove circular references from the JSON.

View File

@@ -1164,6 +1164,34 @@ export type paths = {
patch?: never; patch?: never;
trace?: never; trace?: never;
}; };
"/api/v1/app/client_state": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get Client State By Key
* @description Gets the client state
*/
get: operations["get_client_state_by_key"];
put?: never;
/**
* Set Client State
* @description Sets the client state
*/
post: operations["set_client_state"];
/**
* Delete Client State
* @description Deletes the client state
*/
delete: operations["delete_client_state"];
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/queue/{queue_id}/enqueue_batch": { "/api/v1/queue/{queue_id}/enqueue_batch": {
parameters: { parameters: {
query?: never; query?: never;
@@ -24697,6 +24725,101 @@ export interface operations {
}; };
}; };
}; };
get_client_state_by_key: {
parameters: {
query: {
/** @description Key to get */
key: string;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["JsonValue"] | null;
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
set_client_state: {
parameters: {
query: {
/** @description Key to set */
key: string;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["JsonValue"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": unknown;
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
delete_client_state: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": unknown;
};
};
/** @description Client state deleted */
204: {
headers: {
[name: string]: unknown;
};
content?: never;
};
};
};
enqueue_batch: { enqueue_batch: {
parameters: { parameters: {
query?: never; query?: never;

View File

@@ -1,6 +1,9 @@
import type { Dimensions } from 'features/controlLayers/store/types'; import type { Dimensions } from 'features/controlLayers/store/types';
import type { components, paths } from 'services/api/schema'; import type { components, paths } from 'services/api/schema';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import type { JsonObject, SetRequired } from 'type-fest'; import type { JsonObject, SetRequired } from 'type-fest';
import z from 'zod';
export type S = components['schemas']; export type S = components['schemas'];
@@ -33,10 +36,36 @@ export type InvocationJSONSchemaExtra = S['UIConfigBase'];
export type AppVersion = S['AppVersion']; export type AppVersion = S['AppVersion'];
export type AppConfig = S['AppConfig']; export type AppConfig = S['AppConfig'];
const zResourceOrigin = z.enum(['internal', 'external']);
type ResourceOrigin = z.infer<typeof zResourceOrigin>;
assert<Equals<ResourceOrigin, S['ResourceOrigin']>>();
const zImageCategory = z.enum(['general', 'mask', 'control', 'user', 'other']);
export type ImageCategory = z.infer<typeof zImageCategory>;
assert<Equals<ImageCategory, S['ImageCategory']>>();
// Images // Images
export type ImageDTO = S['ImageDTO']; const _zImageDTO = z.object({
image_name: z.string(),
image_url: z.string(),
thumbnail_url: z.string(),
image_origin: zResourceOrigin,
image_category: zImageCategory,
width: z.number().int().gt(0),
height: z.number().int().gt(0),
created_at: z.string(),
updated_at: z.string(),
deleted_at: z.string().nullish(),
is_intermediate: z.boolean(),
session_id: z.string().nullish(),
node_id: z.string().nullish(),
starred: z.boolean(),
has_workflow: z.boolean(),
board_id: z.string().nullish(),
});
export type ImageDTO = z.infer<typeof _zImageDTO>;
assert<Equals<ImageDTO, S['ImageDTO']>>();
export type BoardDTO = S['BoardDTO']; export type BoardDTO = S['BoardDTO'];
export type ImageCategory = S['ImageCategory'];
export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDTO_']; export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDTO_'];
// Models // Models
@@ -298,8 +327,13 @@ export type ModelInstallStatus = S['InstallStatus'];
export type Graph = S['Graph']; export type Graph = S['Graph'];
export type NonNullableGraph = SetRequired<Graph, 'nodes' | 'edges'>; export type NonNullableGraph = SetRequired<Graph, 'nodes' | 'edges'>;
export type Batch = S['Batch']; export type Batch = S['Batch'];
export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy']; export const zWorkflowRecordOrderBy = z.enum(['name', 'created_at', 'updated_at', 'opened_at']);
export type SQLiteDirection = S['SQLiteDirection']; export type WorkflowRecordOrderBy = z.infer<typeof zWorkflowRecordOrderBy>;
assert<Equals<S['WorkflowRecordOrderBy'], WorkflowRecordOrderBy>>();
export const zSQLiteDirection = z.enum(['ASC', 'DESC']);
export type SQLiteDirection = z.infer<typeof zSQLiteDirection>;
assert<Equals<S['SQLiteDirection'], SQLiteDirection>>();
export type WorkflowRecordListItemWithThumbnailDTO = S['WorkflowRecordListItemWithThumbnailDTO']; export type WorkflowRecordListItemWithThumbnailDTO = S['WorkflowRecordListItemWithThumbnailDTO'];
type KeysOfUnion<T> = T extends T ? keyof T : never; type KeysOfUnion<T> = T extends T ? keyof T : never;

View File

@@ -1,12 +1,12 @@
import { ExternalLink } from '@invoke-ai/ui-library'; import { ExternalLink } from '@invoke-ai/ui-library';
import { isAnyOf } from '@reduxjs/toolkit'; import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { listenerMiddleware } from 'app/store/middleware/listenerMiddleware';
import { socketConnected } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected'; import { socketConnected } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId'; import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
import { $queueId } from 'app/store/nanostores/queueId'; import { $queueId } from 'app/store/nanostores/queueId';
import type { AppStore } from 'app/store/store'; import type { AppStore } from 'app/store/store';
import { listenerMiddleware } from 'app/store/store';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { forEach, isNil, round } from 'es-toolkit/compat'; import { forEach, isNil, round } from 'es-toolkit/compat';
import { import {

View File

@@ -1 +1 @@
__version__ = "6.1.0" __version__ = "6.2.0a1"

View File

@@ -67,6 +67,7 @@ def mock_services() -> InvocationServices:
workflow_thumbnails=None, # type: ignore workflow_thumbnails=None, # type: ignore
model_relationship_records=None, # type: ignore model_relationship_records=None, # type: ignore
model_relationships=None, # type: ignore model_relationships=None, # type: ignore
client_state_persistence=None, # type: ignore
) )