refactor(ui): iterate on persistence

This commit is contained in:
psychedelicious
2025-07-22 16:24:49 +10:00
parent ca0684700e
commit 456205da17
11 changed files with 55 additions and 31 deletions

View File

@@ -1,3 +1,2 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = [];
export const EMPTY_OBJECT = {};

View File

@@ -1,12 +1,29 @@
import { objectEquals } from '@observ33r/object-equals';
import { logger } from 'app/logging/logger';
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
import { $authToken } from 'app/store/nanostores/authToken';
import { $projectId } from 'app/store/nanostores/projectId';
import { $queueId } from 'app/store/nanostores/queueId';
import { $isPendingPersist } from 'app/store/store';
import { atom } from 'nanostores';
import type { Driver } from 'redux-remember';
import { getBaseUrl } from 'services/api';
import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
const log = logger('system');
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
const $persistRefCount = atom(0);
const inc = () => {
$persistRefCount.set($persistRefCount.get() + 1);
};
const dec = () => {
$persistRefCount.set($persistRefCount.get() - 1);
};
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
const lastPersistedState = new Map<string, unknown>();
const getUrl = (key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
@@ -59,6 +76,11 @@ export const serverBackedDriver: Driver = {
},
setItem: async (key, value) => {
try {
inc();
if (objectEquals(lastPersistedState.get(key), value)) {
log.debug(`Skipping persist for key "${key}" as value is unchanged.`);
return value;
}
const url = getUrl(key);
const headers = getHeaders({ 'content-type': 'application/json' });
const res = await fetch(url, { headers, method: 'POST', body: JSON.stringify(value) });
@@ -73,21 +95,31 @@ export const serverBackedDriver: Driver = {
projectId: $projectId.get(),
originalError,
});
} finally {
lastPersistedState.set(key, value);
dec();
}
},
};
export const resetClientState = async () => {
const url = getUrl();
const headers = getHeaders();
const res = await fetch(url, { headers, method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
try {
inc();
const url = getUrl();
const headers = getHeaders();
const res = await fetch(url, { headers, method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch {
log.error('Failed to reset client state');
} finally {
dec();
}
};
window.addEventListener('beforeunload', (e) => {
if ($isPendingPersist.get()) {
if ($persistRefCount.get() > 0) {
e.preventDefault();
}
});

View File

@@ -40,7 +40,6 @@ import { configSliceConfig } from 'features/system/store/configSlice';
import { systemSliceConfig } from 'features/system/store/systemSlice';
import { uiSliceConfig } from 'features/ui/store/uiSlice';
import { diff } from 'jsondiffpatch';
import { atom } from 'nanostores';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
@@ -61,8 +60,6 @@ export const listenerMiddleware = createListenerMiddleware();
const log = logger('system');
// When adding a slice, add the config to the SLICE_CONFIGS object below, then add the reducer to ALL_REDUCERS.
// Remember to wrap undoable slices in `undoable()`.
const SLICE_CONFIGS = {
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig,
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
@@ -85,6 +82,8 @@ const SLICE_CONFIGS = {
[workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig,
};
// TS makes it really hard to dynamically create this object :/ so it's just hardcoded here.
// Remember to wrap undoable reducers in `undoable()`!
const ALL_REDUCERS = {
[api.reducerPath]: api.reducer,
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig.slice.reducer,
@@ -120,12 +119,6 @@ const rootReducer = combineReducers(ALL_REDUCERS);
const rememberedRootReducer = rememberReducer(rootReducer);
export const $isPendingPersist = atom(false);
$isPendingPersist.listen((isPendingPersist) => {
console.log({ isPendingPersist });
});
const unserialize: UnserializeFunction = (data, key) => {
const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
if (!sliceConfig?.persistConfig) {
@@ -153,7 +146,7 @@ const unserialize: UnserializeFunction = (data, key) => {
{
persistedData: parsed,
rehydratedData: transformed as JsonObject,
diff: diff(parsed, transformed) as JsonObject, // this is always serializable
diff: diff(parsed, transformed) as JsonObject,
},
`Rehydrated slice "${key}"`
);
@@ -166,6 +159,7 @@ const unserialize: UnserializeFunction = (data, key) => {
state = getInitialState();
}
// Undoable slices must be wrapped in a history!
if (undoableConfig) {
return newHistory([], state, []);
} else {
@@ -183,11 +177,13 @@ const serialize: SerializeFunction = (data, key) => {
sliceConfig.undoableConfig ? data.present : data,
sliceConfig.persistConfig.persistDenylist ?? []
);
return JSON.stringify(result);
};
const PERSISTED_SLICE_CONFIGS = Object.values(SLICE_CONFIGS).filter(({ persistConfig }) => !!persistConfig);
const PERSISTED_KEYS = PERSISTED_SLICE_CONFIGS.map(({ slice }) => slice.name);
const PERSISTED_KEYS = Object.values(SLICE_CONFIGS)
.filter((sliceConfig) => !!sliceConfig.persistConfig)
.map((sliceConfig) => sliceConfig.slice.reducerPath);
export const createStore = (uniqueStoreKey?: string, persist = true) =>
configureStore({
@@ -209,7 +205,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
if (persist) {
const res = enhancers.prepend(
rememberEnhancer(serverBackedDriver, PERSISTED_KEYS, {
persistDebounce: 3000,
persistThrottle: 2000,
serialize,
unserialize,
prefix: '',

View File

@@ -8,7 +8,7 @@ import { initialState } from './initialState';
const getInitialState = () => deepClone(initialState);
export const slice = createSlice({
const slice = createSlice({
name: 'changeBoardModal',
initialState,
reducers: {

View File

@@ -1703,7 +1703,7 @@ const syncScaledSize = (state: CanvasState) => {
let filter = true;
export const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
limit: 64,
undoType: canvasUndo.type,
redoType: canvasRedo.type,

View File

@@ -37,7 +37,7 @@ type PayloadActionWithId<T = void> = T extends void
} & T
>;
export const slice = createSlice({
const slice = createSlice({
name: 'refImages',
initialState: getInitialRefImagesState(),
reducers: {

View File

@@ -31,7 +31,7 @@ const getInitialState = (): DynamicPromptsState => ({
seedBehaviour: 'PER_ITERATION',
});
export const slice = createSlice({
const slice = createSlice({
name: 'dynamicPrompts',
initialState: getInitialState(),
reducers: {

View File

@@ -31,7 +31,7 @@ const getInitialState = (): UpscaleState => ({
tileOverlap: 128,
});
export const slice = createSlice({
const slice = createSlice({
name: 'upscale',
initialState: getInitialState(),
reducers: {

View File

@@ -15,7 +15,7 @@ const getInitialState = (): StylePresetState => ({
showPromptPreviews: false,
});
export const slice = createSlice({
const slice = createSlice({
name: 'stylePreset',
initialState: getInitialState(),
reducers: {

View File

@@ -1,7 +1,6 @@
import { Flex, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';
import { $isPendingPersist } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import InvokeAILogoComponent from 'features/system/components/InvokeAILogoComponent';
import SettingsMenu from 'features/system/components/SettingsModal/SettingsMenu';
@@ -38,7 +37,6 @@ export const VerticalNavBar = memo(() => {
const withWorkflowsTab = useAppSelector(selectWithWorkflowsTab);
const withModelsTab = useAppSelector(selectWithModelsTab);
const withQueueTab = useAppSelector(selectWithQueueTab);
const isPendingPersist = useStore($isPendingPersist);
return (
<Flex flexDir="column" alignItems="center" py={6} ps={4} pe={2} gap={4} minW={0} flexShrink={0}>
@@ -50,7 +48,6 @@ export const VerticalNavBar = memo(() => {
{withWorkflowsTab && <TabButton tab="workflows" icon={<PiFlowArrowBold />} label={t('ui.tabs.workflows')} />}
{withModelsTab && <TabButton tab="models" icon={<PiCubeBold />} label={t('ui.tabs.models')} />}
{withQueueTab && <TabButton tab="queue" icon={<PiQueueBold />} label={t('ui.tabs.queue')} />}
{isPendingPersist && <Flex w={4} h={4} bg="red" />}
</Flex>
<Spacer />
<StatusIndicator />

View File

@@ -6,7 +6,7 @@ import { deepClone } from 'common/util/deepClone';
import { INITIAL_STATE, type UIState } from './uiTypes';
export const getInitialState = (): UIState => deepClone(INITIAL_STATE);
const getInitialState = (): UIState => deepClone(INITIAL_STATE);
const slice = createSlice({
name: 'ui',