refactor(ui): use new routes for _all_ client state persistence (no override/custom drivers)

This commit is contained in:
psychedelicious
2025-07-30 13:10:34 +10:00
parent 11d68cc646
commit 6784fd5b43
6 changed files with 123 additions and 182 deletions

View File

@@ -2,8 +2,8 @@ import { Box } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import { useClearStorage } from 'app/contexts/clear-storage-context';
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
import { clearStorage } from 'app/store/enhancers/reduxRemember/driver';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import { AppContent } from 'features/ui/components/AppContent';
@@ -21,13 +21,12 @@ interface Props {
const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
const didStudioInit = useStore($didStudioInit);
const clearStorage = useClearStorage();
const handleReset = useCallback(() => {
clearStorage();
location.reload();
return false;
}, [clearStorage]);
}, []);
return (
<ThemeLocaleProvider>

View File

@@ -1,12 +1,11 @@
import 'i18n';
import type { Middleware } from '@reduxjs/toolkit';
import { ClearStorageProvider } from 'app/contexts/clear-storage-context';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { buildStorage } from 'app/store/enhancers/reduxRemember/driver';
import { addStorageListeners } from 'app/store/enhancers/reduxRemember/driver';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
@@ -72,14 +71,7 @@ interface Props extends PropsWithChildren {
* If provided, overrides in-app navigation to the model manager
*/
onClickGoToModelManager?: () => void;
storageConfig?: {
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
getItem: (key: string) => Promise<any>;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
persistThrottle: number;
};
storagePersistThrottle?: number;
}
const InvokeAIUI = ({
@@ -106,7 +98,7 @@ const InvokeAIUI = ({
loggingOverrides,
onClickGoToModelManager,
whatsNew,
storageConfig,
storagePersistThrottle = 2000,
}: Props) => {
useLayoutEffect(() => {
/*
@@ -319,21 +311,13 @@ const InvokeAIUI = ({
};
}, [isDebugging]);
const storage = useMemo(() => buildStorage(storageConfig), [storageConfig]);
useEffect(() => {
const storageCleanup = storage.registerListeners();
return () => {
storageCleanup();
};
}, [storage]);
useEffect(() => addStorageListeners(), []);
const store = useMemo(() => {
return createStore({
driver: storage.reduxRememberDriver,
persistThrottle: storageConfig?.persistThrottle ?? 2000,
persistThrottle: storagePersistThrottle,
});
}, [storage.reduxRememberDriver, storageConfig?.persistThrottle]);
}, [storagePersistThrottle]);
useEffect(() => {
$store.set(store);
@@ -350,13 +334,11 @@ const InvokeAIUI = ({
return (
<React.StrictMode>
<ClearStorageProvider value={storage.clearStorage}>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
</ClearStorageProvider>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
</React.StrictMode>
);
};

View File

@@ -1,10 +0,0 @@
import { createContext, useContext } from 'react';
const ClearStorageContext = createContext<() => void>(() => {});
export const ClearStorageProvider = ClearStorageContext.Provider;
export const useClearStorage = () => {
const context = useContext(ClearStorageContext);
return context;
};

View File

@@ -1,160 +1,131 @@
import { logger } from 'app/logging/logger';
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
import { $projectId } from 'app/store/nanostores/projectId';
import type { Driver as ReduxRememberDriver } from 'redux-remember';
import { getBaseUrl } from 'services/api';
import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
import { $queueId } from 'app/store/nanostores/queueId';
import type { Driver } from 'redux-remember';
import { buildV1Url, getBaseUrl } from 'services/api';
const log = logger('system');
const getUrl = (key?: string) => {
const getClientStateStorageURL = (operation: 'get_by_key' | 'set_by_key' | 'delete', key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
if (key) {
query['key'] = key;
}
const path = buildAppInfoUrl('client_state', query);
const path = buildV1Url(`client_state/${$queueId.get()}/${operation}`, query);
const url = `${baseUrl}/${path}`;
return url;
};
const defaultGetItem = async (key: string): Promise<string | undefined> => {
const url = getUrl(key);
const res = await fetch(url, { method: 'GET' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
return res.json();
};
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
let persistRefCount = 0;
const defaultSetItem = async (key: string, value: string): Promise<string> => {
const url = getUrl(key);
const res = await fetch(url, { method: 'POST', body: value });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
return res.json();
};
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
//
// `redux-remember` persists individual slices of state, so we can implicity denylist a slice by not giving it a
// persist config.
//
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
// way to do this directly.
//
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
// the implementation in `store.ts` for this logic.
//
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
// whole slice, even if the final, _serialized_ slice value is unchanged.
//
// To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to
// be persisted is the same as the last persisted value, we can skip the network request.
const lastPersistedState = new Map<string, string | undefined>();
const defaultClear = async (): Promise<void> => {
const url = getUrl();
const res = await fetch(url, { method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
};
export const buildStorage = (api?: {
getItem: (key: string) => Promise<string | undefined>;
setItem: (key: string, value: string) => Promise<string>;
clear: () => Promise<void>;
}): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
const _api = api ?? {
getItem: defaultGetItem,
setItem: defaultSetItem,
clear: defaultClear,
};
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
let persistRefCount = 0;
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
//
// `redux-remember` persists individual slices of state, so we can implicity denylist a slice by not giving it a
// persist config.
//
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
// way to do this directly.
//
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
// the implementation in `store.ts` for this logic.
//
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
// whole slice, even if the final, _serialized_ slice value is unchanged.
//
// To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to
// be persisted is the same as the last persisted value, we can skip the network request.
const lastPersistedState = new Map<string, string | undefined>();
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
const value = await _api.getItem(key);
lastPersistedState.set(key, value);
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Getting state for ${key}`);
return value;
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
export const reduxRememberDriver: Driver = {
getItem: async (key: string) => {
try {
const url = getClientStateStorageURL('get_by_key', key);
const res = await fetch(url, { method: 'GET' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
},
setItem: async (key, value) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(
{ key, last: lastPersistedState.get(key), next: value },
`Skipping persist for ${key} as value is unchanged`
);
return value;
}
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`);
const resultValue = await _api.setItem(key, value);
lastPersistedState.set(key, resultValue);
return resultValue;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
const clearStorage = async () => {
const value = await res.json();
lastPersistedState.set(key, value);
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Getting state for ${key}`);
return value;
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key: string, value: string) => {
try {
persistRefCount++;
await _api.clear();
} catch {
log.error('Failed to reset client state');
if (lastPersistedState.get(key) === value) {
log.trace(
{ key, last: lastPersistedState.get(key), next: value },
`Skipping persist for ${key} as value is unchanged`
);
return value;
}
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`);
const url = getClientStateStorageURL('set_by_key', key);
const res = await fetch(url, { method: 'POST', body: value });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
const resultValue = await res.json();
lastPersistedState.set(key, resultValue);
return resultValue;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
const registerListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};
return { reduxRememberDriver, clearStorage, registerListeners };
},
};
export const clearStorage = async () => {
try {
persistRefCount++;
const url = getClientStateStorageURL('delete');
const res = await fetch(url, { method: 'POST' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch {
log.error('Failed to reset client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
export const addStorageListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};

View File

@@ -40,7 +40,7 @@ import { systemSliceConfig } from 'features/system/store/systemSlice';
import { uiSliceConfig } from 'features/ui/store/uiSlice';
import { diff } from 'jsondiffpatch';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { Driver, SerializeFunction, UnserializeFunction } from 'redux-remember';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import undoable, { newHistory } from 'redux-undo';
import { serializeError } from 'serialize-error';
@@ -48,6 +48,7 @@ import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest';
import { reduxRememberDriver } from './enhancers/reduxRemember/driver';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
@@ -183,7 +184,7 @@ const PERSISTED_KEYS = Object.values(SLICE_CONFIGS)
.filter((sliceConfig) => !!sliceConfig.persistConfig)
.map((sliceConfig) => sliceConfig.slice.reducerPath);
export const createStore = (reduxRememberOptions: { driver: Driver; persistThrottle: number }) =>
export const createStore = (options: { persistThrottle: number }) =>
configureStore({
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
@@ -201,8 +202,8 @@ export const createStore = (reduxRememberOptions: { driver: Driver; persistThrot
enhancers: (getDefaultEnhancers) => {
const enhancers = getDefaultEnhancers();
return enhancers.prepend(
rememberEnhancer(reduxRememberOptions.driver, PERSISTED_KEYS, {
persistThrottle: reduxRememberOptions.persistThrottle,
rememberEnhancer(reduxRememberDriver, PERSISTED_KEYS, {
persistThrottle: options.persistThrottle,
serialize,
unserialize,
prefix: '',

View File

@@ -14,7 +14,7 @@ import {
Switch,
Text,
} from '@invoke-ai/ui-library';
import { useClearStorage } from 'app/contexts/clear-storage-context';
import { clearStorage } from 'app/store/enhancers/reduxRemember/driver';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
@@ -115,8 +115,6 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
dispatch(shouldConfirmOnNewSessionToggled());
}, [dispatch]);
const clearStorage = useClearStorage();
useEffect(() => {
if (settingsModal.isTrue && Boolean(config?.shouldShowClearIntermediates)) {
refetchIntermediatesCount();
@@ -127,7 +125,7 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
clearStorage();
settingsModal.setFalse();
refreshModal.setTrue();
}, [clearStorage, settingsModal, refreshModal]);
}, [settingsModal, refreshModal]);
const handleChangeShouldConfirmOnDelete = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {