mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 05:05:21 -05:00
refactor(ui): iterate on persistence
This commit is contained in:
@@ -1,3 +1,2 @@
|
||||
export const STORAGE_PREFIX = '@@invokeai-';
|
||||
export const EMPTY_ARRAY = [];
|
||||
export const EMPTY_OBJECT = {};
|
||||
|
||||
@@ -1,12 +1,29 @@
|
||||
import { objectEquals } from '@observ33r/object-equals';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { $projectId } from 'app/store/nanostores/projectId';
|
||||
import { $queueId } from 'app/store/nanostores/queueId';
|
||||
import { $isPendingPersist } from 'app/store/store';
|
||||
import { atom } from 'nanostores';
|
||||
import type { Driver } from 'redux-remember';
|
||||
import { getBaseUrl } from 'services/api';
|
||||
import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
|
||||
// it when a slice is being persisted and decrementing it when the persistence is done.
|
||||
const $persistRefCount = atom(0);
|
||||
const inc = () => {
|
||||
$persistRefCount.set($persistRefCount.get() + 1);
|
||||
};
|
||||
const dec = () => {
|
||||
$persistRefCount.set($persistRefCount.get() - 1);
|
||||
};
|
||||
|
||||
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
|
||||
const lastPersistedState = new Map<string, unknown>();
|
||||
|
||||
const getUrl = (key?: string) => {
|
||||
const baseUrl = getBaseUrl();
|
||||
const query: Record<string, string> = {};
|
||||
@@ -59,6 +76,11 @@ export const serverBackedDriver: Driver = {
|
||||
},
|
||||
setItem: async (key, value) => {
|
||||
try {
|
||||
inc();
|
||||
if (objectEquals(lastPersistedState.get(key), value)) {
|
||||
log.debug(`Skipping persist for key "${key}" as value is unchanged.`);
|
||||
return value;
|
||||
}
|
||||
const url = getUrl(key);
|
||||
const headers = getHeaders({ 'content-type': 'application/json' });
|
||||
const res = await fetch(url, { headers, method: 'POST', body: JSON.stringify(value) });
|
||||
@@ -73,21 +95,31 @@ export const serverBackedDriver: Driver = {
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
} finally {
|
||||
lastPersistedState.set(key, value);
|
||||
dec();
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
export const resetClientState = async () => {
|
||||
const url = getUrl();
|
||||
const headers = getHeaders();
|
||||
const res = await fetch(url, { headers, method: 'DELETE' });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
try {
|
||||
inc();
|
||||
const url = getUrl();
|
||||
const headers = getHeaders();
|
||||
const res = await fetch(url, { headers, method: 'DELETE' });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
} catch {
|
||||
log.error('Failed to reset client state');
|
||||
} finally {
|
||||
dec();
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('beforeunload', (e) => {
|
||||
if ($isPendingPersist.get()) {
|
||||
if ($persistRefCount.get() > 0) {
|
||||
e.preventDefault();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -40,7 +40,6 @@ import { configSliceConfig } from 'features/system/store/configSlice';
|
||||
import { systemSliceConfig } from 'features/system/store/systemSlice';
|
||||
import { uiSliceConfig } from 'features/ui/store/uiSlice';
|
||||
import { diff } from 'jsondiffpatch';
|
||||
import { atom } from 'nanostores';
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
|
||||
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||
@@ -61,8 +60,6 @@ export const listenerMiddleware = createListenerMiddleware();
|
||||
const log = logger('system');
|
||||
|
||||
// When adding a slice, add the config to the SLICE_CONFIGS object below, then add the reducer to ALL_REDUCERS.
|
||||
// Remember to wrap undoable slices in `undoable()`.
|
||||
|
||||
const SLICE_CONFIGS = {
|
||||
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig,
|
||||
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
|
||||
@@ -85,6 +82,8 @@ const SLICE_CONFIGS = {
|
||||
[workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig,
|
||||
};
|
||||
|
||||
// TS makes it really hard to dynamically create this object :/ so it's just hardcoded here.
|
||||
// Remember to wrap undoable reducers in `undoable()`!
|
||||
const ALL_REDUCERS = {
|
||||
[api.reducerPath]: api.reducer,
|
||||
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig.slice.reducer,
|
||||
@@ -120,12 +119,6 @@ const rootReducer = combineReducers(ALL_REDUCERS);
|
||||
|
||||
const rememberedRootReducer = rememberReducer(rootReducer);
|
||||
|
||||
export const $isPendingPersist = atom(false);
|
||||
|
||||
$isPendingPersist.listen((isPendingPersist) => {
|
||||
console.log({ isPendingPersist });
|
||||
});
|
||||
|
||||
const unserialize: UnserializeFunction = (data, key) => {
|
||||
const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
|
||||
if (!sliceConfig?.persistConfig) {
|
||||
@@ -153,7 +146,7 @@ const unserialize: UnserializeFunction = (data, key) => {
|
||||
{
|
||||
persistedData: parsed,
|
||||
rehydratedData: transformed as JsonObject,
|
||||
diff: diff(parsed, transformed) as JsonObject, // this is always serializable
|
||||
diff: diff(parsed, transformed) as JsonObject,
|
||||
},
|
||||
`Rehydrated slice "${key}"`
|
||||
);
|
||||
@@ -166,6 +159,7 @@ const unserialize: UnserializeFunction = (data, key) => {
|
||||
state = getInitialState();
|
||||
}
|
||||
|
||||
// Undoable slices must be wrapped in a history!
|
||||
if (undoableConfig) {
|
||||
return newHistory([], state, []);
|
||||
} else {
|
||||
@@ -183,11 +177,13 @@ const serialize: SerializeFunction = (data, key) => {
|
||||
sliceConfig.undoableConfig ? data.present : data,
|
||||
sliceConfig.persistConfig.persistDenylist ?? []
|
||||
);
|
||||
|
||||
return JSON.stringify(result);
|
||||
};
|
||||
|
||||
const PERSISTED_SLICE_CONFIGS = Object.values(SLICE_CONFIGS).filter(({ persistConfig }) => !!persistConfig);
|
||||
const PERSISTED_KEYS = PERSISTED_SLICE_CONFIGS.map(({ slice }) => slice.name);
|
||||
const PERSISTED_KEYS = Object.values(SLICE_CONFIGS)
|
||||
.filter((sliceConfig) => !!sliceConfig.persistConfig)
|
||||
.map((sliceConfig) => sliceConfig.slice.reducerPath);
|
||||
|
||||
export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
configureStore({
|
||||
@@ -209,7 +205,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
if (persist) {
|
||||
const res = enhancers.prepend(
|
||||
rememberEnhancer(serverBackedDriver, PERSISTED_KEYS, {
|
||||
persistDebounce: 3000,
|
||||
persistThrottle: 2000,
|
||||
serialize,
|
||||
unserialize,
|
||||
prefix: '',
|
||||
|
||||
@@ -8,7 +8,7 @@ import { initialState } from './initialState';
|
||||
|
||||
const getInitialState = () => deepClone(initialState);
|
||||
|
||||
export const slice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'changeBoardModal',
|
||||
initialState,
|
||||
reducers: {
|
||||
|
||||
@@ -1703,7 +1703,7 @@ const syncScaledSize = (state: CanvasState) => {
|
||||
|
||||
let filter = true;
|
||||
|
||||
export const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
|
||||
const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
|
||||
limit: 64,
|
||||
undoType: canvasUndo.type,
|
||||
redoType: canvasRedo.type,
|
||||
|
||||
@@ -37,7 +37,7 @@ type PayloadActionWithId<T = void> = T extends void
|
||||
} & T
|
||||
>;
|
||||
|
||||
export const slice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'refImages',
|
||||
initialState: getInitialRefImagesState(),
|
||||
reducers: {
|
||||
|
||||
@@ -31,7 +31,7 @@ const getInitialState = (): DynamicPromptsState => ({
|
||||
seedBehaviour: 'PER_ITERATION',
|
||||
});
|
||||
|
||||
export const slice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'dynamicPrompts',
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
|
||||
@@ -31,7 +31,7 @@ const getInitialState = (): UpscaleState => ({
|
||||
tileOverlap: 128,
|
||||
});
|
||||
|
||||
export const slice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'upscale',
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
|
||||
@@ -15,7 +15,7 @@ const getInitialState = (): StylePresetState => ({
|
||||
showPromptPreviews: false,
|
||||
});
|
||||
|
||||
export const slice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'stylePreset',
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';
|
||||
import { $isPendingPersist } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import InvokeAILogoComponent from 'features/system/components/InvokeAILogoComponent';
|
||||
import SettingsMenu from 'features/system/components/SettingsModal/SettingsMenu';
|
||||
@@ -38,7 +37,6 @@ export const VerticalNavBar = memo(() => {
|
||||
const withWorkflowsTab = useAppSelector(selectWithWorkflowsTab);
|
||||
const withModelsTab = useAppSelector(selectWithModelsTab);
|
||||
const withQueueTab = useAppSelector(selectWithQueueTab);
|
||||
const isPendingPersist = useStore($isPendingPersist);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" alignItems="center" py={6} ps={4} pe={2} gap={4} minW={0} flexShrink={0}>
|
||||
@@ -50,7 +48,6 @@ export const VerticalNavBar = memo(() => {
|
||||
{withWorkflowsTab && <TabButton tab="workflows" icon={<PiFlowArrowBold />} label={t('ui.tabs.workflows')} />}
|
||||
{withModelsTab && <TabButton tab="models" icon={<PiCubeBold />} label={t('ui.tabs.models')} />}
|
||||
{withQueueTab && <TabButton tab="queue" icon={<PiQueueBold />} label={t('ui.tabs.queue')} />}
|
||||
{isPendingPersist && <Flex w={4} h={4} bg="red" />}
|
||||
</Flex>
|
||||
<Spacer />
|
||||
<StatusIndicator />
|
||||
|
||||
@@ -6,7 +6,7 @@ import { deepClone } from 'common/util/deepClone';
|
||||
|
||||
import { INITIAL_STATE, type UIState } from './uiTypes';
|
||||
|
||||
export const getInitialState = (): UIState => deepClone(INITIAL_STATE);
|
||||
const getInitialState = (): UIState => deepClone(INITIAL_STATE);
|
||||
|
||||
const slice = createSlice({
|
||||
name: 'ui',
|
||||
|
||||
Reference in New Issue
Block a user