From 7a23d8266fd580c87f6039f3aab0333ad9abfa35 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 29 Jul 2025 19:47:54 +1000 Subject: [PATCH] feat(ui): simpler storage driver impl --- .../web/src/app/components/InvokeAIUI.tsx | 4 +- .../store/enhancers/reduxRemember/driver.ts | 205 ++++++------------ invokeai/frontend/web/src/app/store/store.ts | 5 +- 3 files changed, 66 insertions(+), 148 deletions(-) diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 7e420190fc..d6956de260 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -6,7 +6,7 @@ import type { StudioInitAction } from 'app/hooks/useStudioInitAction'; import { $didStudioInit } from 'app/hooks/useStudioInitAction'; import type { LoggingOverrides } from 'app/logging/logger'; import { $loggingOverrides, configureLogging } from 'app/logging/logger'; -import { buildStorageApi } from 'app/store/enhancers/reduxRemember/driver'; +import { buildStorage } from 'app/store/enhancers/reduxRemember/driver'; import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink'; import { $authToken } from 'app/store/nanostores/authToken'; import { $baseUrl } from 'app/store/nanostores/baseUrl'; @@ -319,7 +319,7 @@ const InvokeAIUI = ({ }; }, [isDebugging]); - const storage = useMemo(() => buildStorageApi(storageConfig), [storageConfig]); + const storage = useMemo(() => buildStorage(storageConfig), [storageConfig]); useEffect(() => { const storageCleanup = storage.registerListeners(); diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts index 764c2c4a11..049ab4a9d2 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts @@ -1,5 +1,3 @@ -/* eslint-disable @typescript-eslint/no-explicit-any */ - import { logger } from 'app/logging/logger'; import { StorageError } from 'app/store/enhancers/reduxRemember/errors'; import { $projectId } from 'app/store/nanostores/projectId'; @@ -9,11 +7,57 @@ import { buildAppInfoUrl } from 'services/api/endpoints/appInfo'; const log = logger('system'); -const buildOSSServerBackedDriver = (): { +const getUrl = (key?: string) => { + const baseUrl = getBaseUrl(); + const query: Record = {}; + if (key) { + query['key'] = key; + } + const path = buildAppInfoUrl('client_state', query); + const url = `${baseUrl}/${path}`; + return url; +}; + +const defaultGetItem = async (key: string): Promise => { + const url = getUrl(key); + const res = await fetch(url, { method: 'GET' }); + if (!res.ok) { + throw new Error(`Response status: ${res.status}`); + } + return res.json(); +}; + +const defaultSetItem = async (key: string, value: string): Promise => { + const url = getUrl(key); + const res = await fetch(url, { method: 'POST', body: value }); + if (!res.ok) { + throw new Error(`Response status: ${res.status}`); + } + return res.json(); +}; + +const defaultClear = async (): Promise => { + const url = getUrl(); + const res = await fetch(url, { method: 'DELETE' }); + if (!res.ok) { + throw new Error(`Response status: ${res.status}`); + } +}; + +export const buildStorage = (api?: { + getItem: (key: string) => Promise; + setItem: (key: string, value: string) => Promise; + clear: () => Promise; +}): { reduxRememberDriver: ReduxRememberDriver; clearStorage: () => Promise; registerListeners: () => () => void; } => { + const _api = api ?? { + getItem: defaultGetItem, + setItem: defaultSetItem, + clear: defaultClear, + }; // Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing // it when a slice is being persisted and decrementing it when the persistence is done. let persistRefCount = 0; @@ -35,32 +79,15 @@ const buildOSSServerBackedDriver = (): { // // To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to // be persisted is the same as the last persisted value, we can skip the network request. - const lastPersistedState = new Map(); - - const getUrl = (key?: string) => { - const baseUrl = getBaseUrl(); - const query: Record = {}; - if (key) { - query['key'] = key; - } - const path = buildAppInfoUrl('client_state', query); - const url = `${baseUrl}/${path}`; - return url; - }; + const lastPersistedState = new Map(); const reduxRememberDriver: ReduxRememberDriver = { getItem: async (key) => { try { - const url = getUrl(key); - const res = await fetch(url, { method: 'GET' }); - if (!res.ok) { - throw new Error(`Response status: ${res.status}`); - } - const text = await res.text(); - if (!lastPersistedState.get(key)) { - lastPersistedState.set(key, text); - } - return JSON.parse(text); + const value = await _api.getItem(key); + lastPersistedState.set(key, value); + log.trace({ key, last: lastPersistedState.get(key), next: value }, `Getting state for ${key}`); + return value; } catch (originalError) { throw new StorageError({ key, @@ -73,20 +100,16 @@ const buildOSSServerBackedDriver = (): { try { persistRefCount++; if (lastPersistedState.get(key) === value) { - log.trace(`Skipping persist for key "${key}" as value is unchanged.`); + log.trace( + { key, last: lastPersistedState.get(key), next: value }, + `Skipping persist for ${key} as value is unchanged` + ); return value; } - const url = getUrl(key); - const headers = new Headers({ - 'Content-Type': 'application/json', - }); - const res = await fetch(url, { method: 'POST', headers, body: value }); - if (!res.ok) { - throw new Error(`Response status: ${res.status}`); - } - - lastPersistedState.set(key, value); - return value; + log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`); + const resultValue = await _api.setItem(key, value); + lastPersistedState.set(key, resultValue); + return resultValue; } catch (originalError) { throw new StorageError({ key, @@ -107,11 +130,7 @@ const buildOSSServerBackedDriver = (): { const clearStorage = async () => { try { persistRefCount++; - const url = getUrl(); - const res = await fetch(url, { method: 'DELETE' }); - if (!res.ok) { - throw new Error(`Response status: ${res.status}`); - } + await _api.clear(); } catch { log.error('Failed to reset client state'); } finally { @@ -139,105 +158,3 @@ const buildOSSServerBackedDriver = (): { return { reduxRememberDriver, clearStorage, registerListeners }; }; - -const buildCustomDriver = (api: { - getItem: (key: string) => Promise; - setItem: (key: string, value: any) => Promise; - clear: () => Promise; -}): { - reduxRememberDriver: ReduxRememberDriver; - clearStorage: () => Promise; - registerListeners: () => () => void; -} => { - // See the comment in `buildOSSServerBackedDriver` for an explanation of this variable. - let persistRefCount = 0; - - // See the comment in `buildOSSServerBackedDriver` for an explanation of this variable. - const lastPersistedState = new Map(); - - const reduxRememberDriver: ReduxRememberDriver = { - getItem: async (key) => { - try { - log.trace(`Getting client state for key "${key}"`); - return await api.getItem(key); - } catch (originalError) { - throw new StorageError({ - key, - projectId: $projectId.get(), - originalError, - }); - } - }, - setItem: async (key, value) => { - try { - persistRefCount++; - - if (lastPersistedState.get(key) === value) { - log.trace(`Skipping setting client state for key "${key}" as value is unchanged`); - return value; - } - log.trace(`Setting client state for key "${key}", ${value}`); - await api.setItem(key, value); - lastPersistedState.set(key, value); - return value; - } catch (originalError) { - throw new StorageError({ - key, - value, - projectId: $projectId.get(), - originalError, - }); - } finally { - persistRefCount--; - if (persistRefCount < 0) { - log.trace('Persist ref count is negative, resetting to 0'); - persistRefCount = 0; - } - } - }, - }; - - const clearStorage = async () => { - try { - persistRefCount++; - log.trace('Clearing client state'); - await api.clear(); - } catch { - log.error('Failed to clear client state'); - } finally { - persistRefCount--; - lastPersistedState.clear(); - if (persistRefCount < 0) { - log.trace('Persist ref count is negative, resetting to 0'); - persistRefCount = 0; - } - } - }; - - const registerListeners = () => { - const onBeforeUnload = (e: BeforeUnloadEvent) => { - if (persistRefCount > 0) { - e.preventDefault(); - } - }; - window.addEventListener('beforeunload', onBeforeUnload); - - return () => { - window.removeEventListener('beforeunload', onBeforeUnload); - }; - }; - - return { reduxRememberDriver, clearStorage, registerListeners }; -}; - -export const buildStorageApi = (api?: { - getItem: (key: string) => Promise; - setItem: (key: string, value: any) => Promise; - clear: () => Promise; -}) => { - if (api) { - return buildCustomDriver(api); - } else { - return buildOSSServerBackedDriver(); - } -}; diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 150056c210..099e2fb32c 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -127,9 +127,10 @@ const unserialize: UnserializeFunction = (data, key) => { let state; try { const initialState = getInitialState(); + const parsed = JSON.parse(data); // strip out old keys - const stripped = pick(deepClone(data), keys(initialState)); + const stripped = pick(deepClone(parsed), keys(initialState)); /* * Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep, * but that merges arrays by index and partial objects by key. Using an identity function as the customizer results @@ -141,7 +142,7 @@ const unserialize: UnserializeFunction = (data, key) => { log.debug( { - persistedData: data as JsonObject, + persistedData: parsed as JsonObject, rehydratedData: migrated as JsonObject, diff: diff(data, migrated) as JsonObject, },