refactor(ui): restructure persistence driver creation to support custom drivers

This commit is contained in:
psychedelicious
2025-07-24 13:58:36 +10:00
parent 28e7a83f98
commit 37e25ccbf7
8 changed files with 193 additions and 135 deletions

View File

@@ -26,7 +26,7 @@ i18n.use(initReactI18next).init({
returnNull: false,
});
const store = createStore(undefined, false);
const store = createStore({ getItem: () => {}, setItem: () => {} }, false);
$store.set(store);
$baseUrl.set('http://localhost:9090');

View File

@@ -2,10 +2,10 @@ import { Box } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import { useClearStorage } from 'app/contexts/clear-storage-context';
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { AppContent } from 'features/ui/components/AppContent';
import { memo, useCallback } from 'react';
import { ErrorBoundary } from 'react-error-boundary';

View File

@@ -1,16 +1,12 @@
import 'i18n';
import type { Middleware } from '@reduxjs/toolkit';
import { ClearStorageProvider } from 'app/contexts/clear-storage-context';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import {
$resetClientState,
buildDriver,
buildResetClientState,
type StorageDriverApi,
} from 'app/store/enhancers/reduxRemember/driver';
import { buildStorageApi, type StorageDriverApi } from 'app/store/enhancers/reduxRemember/driver';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
@@ -316,18 +312,18 @@ const InvokeAIUI = ({
};
}, [isDebugging]);
useEffect(() => {
$resetClientState.set(buildResetClientState(storageDriverApi));
const storage = useMemo(() => buildStorageApi(storageDriverApi), [storageDriverApi]);
useEffect(() => {
const storageCleanup = storage.registerListeners();
return () => {
$resetClientState.set(() => {});
storageCleanup();
};
}, [storageDriverApi]);
}, [storage]);
const store = useMemo(() => {
const driver = buildDriver(storageDriverApi);
return createStore(driver);
}, [storageDriverApi]);
return createStore(storage.reduxRememberDriver);
}, [storage.reduxRememberDriver]);
useEffect(() => {
$store.set(store);
@@ -344,11 +340,13 @@ const InvokeAIUI = ({
return (
<React.StrictMode>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
<ClearStorageProvider value={storage.clearStorage}>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
</ClearStorageProvider>
</React.StrictMode>
);
};

View File

@@ -0,0 +1,10 @@
import { createContext, useContext } from 'react';
const ClearStorageContext = createContext<() => void>(() => {});
export const ClearStorageProvider = ClearStorageContext.Provider;
export const useClearStorage = () => {
const context = useContext(ClearStorageContext);
return context;
};

View File

@@ -1,13 +1,35 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { logger } from 'app/logging/logger';
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
import { $authToken } from 'app/store/nanostores/authToken';
import { $projectId } from 'app/store/nanostores/projectId';
import { $queueId } from 'app/store/nanostores/queueId';
import { atom } from 'nanostores';
import type { Driver } from 'redux-remember';
import type { Driver as ReduxRememberDriver } from 'redux-remember';
import { getBaseUrl } from 'services/api';
import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
let persistRefCount = 0;
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
//
// `redux-remember` persists individual slices of state, so we can implicity denylist a slice by not giving it a
// persist config.
//
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
// way to do this directly.
//
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
// the implementation in `store.ts` for this logic.
//
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
// whole slice, even if the final, _serialized_ slice value is unchanged.
//
// To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to
// be persisted is the same as the last persisted value, we can skip the network request.
const lastPersistedState = new Map<string, unknown>();
export type StorageDriverApi = {
getItem: (key: string) => Promise<any>;
setItem: (key: string, value: any) => Promise<any>;
@@ -16,55 +38,27 @@ export type StorageDriverApi = {
const log = logger('system');
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
let persistRefCount = 0;
const buildOSSServerBackedDriver = (): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
const getUrl = (key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
if (key) {
query['key'] = key;
}
const path = buildAppInfoUrl('client_state', query);
const url = `${baseUrl}/${path}`;
return url;
};
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
const lastPersistedState = new Map<string, unknown>();
const getUrl = (key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
if (key) {
query['key'] = key;
}
const queueId = $queueId.get();
if (queueId) {
query['queueId'] = queueId;
}
const path = buildAppInfoUrl('client_state', query);
const url = `${baseUrl}/${path}`;
return url;
};
const getHeaders = (extra?: Record<string, string>) => {
const headers = new Headers();
const authToken = $authToken.get();
if (authToken) {
headers.set('Authorization', `Bearer ${authToken}`);
}
const projectId = $projectId.get();
if (projectId) {
headers.set('project-id', projectId);
}
for (const [key, value] of Object.entries(extra ?? {})) {
headers.set(key, value);
}
return headers;
};
export const buildDriver = (api?: StorageDriverApi): Driver => {
return {
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
if (api) {
log.trace(`Using provided API to get item for key "${key}"`);
return await api.getItem(key);
}
const url = getUrl(key);
const headers = getHeaders();
const res = await fetch(url, { headers, method: 'GET' });
const res = await fetch(url, { method: 'GET' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
@@ -81,37 +75,98 @@ export const buildDriver = (api?: StorageDriverApi): Driver => {
setItem: async (key, value) => {
try {
persistRefCount++;
if (api) {
log.trace(`Using provided API to get item for key "${key}"`);
return await api.setItem(key, value);
}
// Deep equality check to avoid noop persist network requests.
//
// `redux-remember` persists individual slices of state, so we can implicity denylist a slice by not giving it a
// persist config.
//
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
// way to do this directly.
//
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
// the implementation in `store.ts` for this logic.
//
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
// whole slice, even if the final, _serialized_ slice value is unchanged.
//
// To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to
// be persisted is the same as the last persisted value, we skip the network request.
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping persist for key "${key}" as value is unchanged.`);
return value;
}
const url = getUrl(key);
const headers = getHeaders({ 'content-type': 'application/json' });
const res = await fetch(url, { headers, method: 'POST', body: JSON.stringify(value) });
const res = await fetch(url, { method: 'POST', body: value });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
const clearStorage = async () => {
try {
persistRefCount++;
const url = getUrl();
const res = await fetch(url, { method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch {
log.error('Failed to reset client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
const registerListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};
return { reduxRememberDriver, clearStorage, registerListeners };
};
const buildCustomDriver = (
api: StorageDriverApi
): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
log.trace(`Getting client state for key "${key}"`);
return await api.getItem(key);
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key, value) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping setting client state for key "${key}" as value is unchanged`);
return value;
}
log.trace(`Setting client state for key "${key}", ${value}`);
await api.setItem(key, value);
lastPersistedState.set(key, value);
return value;
} catch (originalError) {
@@ -124,44 +179,50 @@ export const buildDriver = (api?: StorageDriverApi): Driver => {
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.warn('Persist ref count is negative, resetting to 0');
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
};
export const $resetClientState = atom(() => {});
export const buildResetClientState = (api?: StorageDriverApi) => async () => {
try {
persistRefCount++;
if (api) {
log.trace('Using provided API to reset client state');
const clearStorage = async () => {
try {
persistRefCount++;
log.trace('Clearing client state');
await api.clear();
return;
} catch {
log.error('Failed to clear client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
const url = getUrl();
const headers = getHeaders();
const res = await fetch(url, { headers, method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch {
log.error('Failed to reset client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.warn('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
const registerListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};
return { reduxRememberDriver, clearStorage, registerListeners };
};
window.addEventListener('beforeunload', (e) => {
if (persistRefCount > 0) {
e.preventDefault();
export const buildStorageApi = (driverApi?: StorageDriverApi) => {
if (driverApi) {
return buildCustomDriver(driverApi);
} else {
return buildOSSServerBackedDriver();
}
});
};

View File

@@ -231,7 +231,7 @@ export const createStore = (driver: Driver, persist = true) =>
export type AppStore = ReturnType<typeof createStore>;
export type RootState = ReturnType<AppStore['getState']>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
export type AppGetState = ReturnType<typeof createStore>['getState'];

View File

@@ -1,11 +0,0 @@
import { $resetClientState } from 'app/store/enhancers/reduxRemember/driver';
import { useCallback } from 'react';
export const useClearStorage = () => {
const clearStorage = useCallback(() => {
$resetClientState.get()();
localStorage.clear();
}, []);
return clearStorage;
};

View File

@@ -14,11 +14,11 @@ import {
Switch,
Text,
} from '@invoke-ai/ui-library';
import { useClearStorage } from 'app/contexts/clear-storage-context';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { selectShouldUseCPUNoise, shouldUseCpuNoiseChanged } from 'features/controlLayers/store/paramsSlice';
import { useRefreshAfterResetModal } from 'features/system/components/SettingsModal/RefreshAfterResetModal';
import { SettingsDeveloperLogIsEnabled } from 'features/system/components/SettingsModal/SettingsDeveloperLogIsEnabled';