feat(ui): simpler storage driver impl

This commit is contained in:
psychedelicious
2025-07-29 19:47:54 +10:00
parent a44de079dd
commit 7a23d8266f
3 changed files with 66 additions and 148 deletions

View File

@@ -6,7 +6,7 @@ import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { buildStorageApi } from 'app/store/enhancers/reduxRemember/driver';
import { buildStorage } from 'app/store/enhancers/reduxRemember/driver';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
@@ -319,7 +319,7 @@ const InvokeAIUI = ({
};
}, [isDebugging]);
const storage = useMemo(() => buildStorageApi(storageConfig), [storageConfig]);
const storage = useMemo(() => buildStorage(storageConfig), [storageConfig]);
useEffect(() => {
const storageCleanup = storage.registerListeners();

View File

@@ -1,5 +1,3 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { logger } from 'app/logging/logger';
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
import { $projectId } from 'app/store/nanostores/projectId';
@@ -9,11 +7,57 @@ import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
const log = logger('system');
const buildOSSServerBackedDriver = (): {
const getUrl = (key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
if (key) {
query['key'] = key;
}
const path = buildAppInfoUrl('client_state', query);
const url = `${baseUrl}/${path}`;
return url;
};
const defaultGetItem = async (key: string): Promise<string | undefined> => {
const url = getUrl(key);
const res = await fetch(url, { method: 'GET' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
return res.json();
};
const defaultSetItem = async (key: string, value: string): Promise<string> => {
const url = getUrl(key);
const res = await fetch(url, { method: 'POST', body: value });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
return res.json();
};
const defaultClear = async (): Promise<void> => {
const url = getUrl();
const res = await fetch(url, { method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
};
export const buildStorage = (api?: {
getItem: (key: string) => Promise<string | undefined>;
setItem: (key: string, value: string) => Promise<string>;
clear: () => Promise<void>;
}): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
const _api = api ?? {
getItem: defaultGetItem,
setItem: defaultSetItem,
clear: defaultClear,
};
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
let persistRefCount = 0;
@@ -35,32 +79,15 @@ const buildOSSServerBackedDriver = (): {
//
// To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to
// be persisted is the same as the last persisted value, we can skip the network request.
const lastPersistedState = new Map<string, unknown>();
const getUrl = (key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
if (key) {
query['key'] = key;
}
const path = buildAppInfoUrl('client_state', query);
const url = `${baseUrl}/${path}`;
return url;
};
const lastPersistedState = new Map<string, string | undefined>();
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
const url = getUrl(key);
const res = await fetch(url, { method: 'GET' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
const text = await res.text();
if (!lastPersistedState.get(key)) {
lastPersistedState.set(key, text);
}
return JSON.parse(text);
const value = await _api.getItem(key);
lastPersistedState.set(key, value);
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Getting state for ${key}`);
return value;
} catch (originalError) {
throw new StorageError({
key,
@@ -73,20 +100,16 @@ const buildOSSServerBackedDriver = (): {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping persist for key "${key}" as value is unchanged.`);
log.trace(
{ key, last: lastPersistedState.get(key), next: value },
`Skipping persist for ${key} as value is unchanged`
);
return value;
}
const url = getUrl(key);
const headers = new Headers({
'Content-Type': 'application/json',
});
const res = await fetch(url, { method: 'POST', headers, body: value });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
lastPersistedState.set(key, value);
return value;
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`);
const resultValue = await _api.setItem(key, value);
lastPersistedState.set(key, resultValue);
return resultValue;
} catch (originalError) {
throw new StorageError({
key,
@@ -107,11 +130,7 @@ const buildOSSServerBackedDriver = (): {
const clearStorage = async () => {
try {
persistRefCount++;
const url = getUrl();
const res = await fetch(url, { method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
await _api.clear();
} catch {
log.error('Failed to reset client state');
} finally {
@@ -139,105 +158,3 @@ const buildOSSServerBackedDriver = (): {
return { reduxRememberDriver, clearStorage, registerListeners };
};
const buildCustomDriver = (api: {
getItem: (key: string) => Promise<any>;
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
}): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
let persistRefCount = 0;
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
const lastPersistedState = new Map<string, unknown>();
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
log.trace(`Getting client state for key "${key}"`);
return await api.getItem(key);
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key, value) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping setting client state for key "${key}" as value is unchanged`);
return value;
}
log.trace(`Setting client state for key "${key}", ${value}`);
await api.setItem(key, value);
lastPersistedState.set(key, value);
return value;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
const clearStorage = async () => {
try {
persistRefCount++;
log.trace('Clearing client state');
await api.clear();
} catch {
log.error('Failed to clear client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
const registerListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};
return { reduxRememberDriver, clearStorage, registerListeners };
};
export const buildStorageApi = (api?: {
getItem: (key: string) => Promise<any>;
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
}) => {
if (api) {
return buildCustomDriver(api);
} else {
return buildOSSServerBackedDriver();
}
};

View File

@@ -127,9 +127,10 @@ const unserialize: UnserializeFunction = (data, key) => {
let state;
try {
const initialState = getInitialState();
const parsed = JSON.parse(data);
// strip out old keys
const stripped = pick(deepClone(data), keys(initialState));
const stripped = pick(deepClone(parsed), keys(initialState));
/*
* Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep,
* but that merges arrays by index and partial objects by key. Using an identity function as the customizer results
@@ -141,7 +142,7 @@ const unserialize: UnserializeFunction = (data, key) => {
log.debug(
{
persistedData: data as JsonObject,
persistedData: parsed as JsonObject,
rehydratedData: migrated as JsonObject,
diff: diff(data, migrated) as JsonObject,
},