mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 04:45:08 -05:00
feat(ui): simpler storage driver impl
This commit is contained in:
@@ -6,7 +6,7 @@ import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
|
||||
import type { LoggingOverrides } from 'app/logging/logger';
|
||||
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
|
||||
import { buildStorageApi } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { buildStorage } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
@@ -319,7 +319,7 @@ const InvokeAIUI = ({
|
||||
};
|
||||
}, [isDebugging]);
|
||||
|
||||
const storage = useMemo(() => buildStorageApi(storageConfig), [storageConfig]);
|
||||
const storage = useMemo(() => buildStorage(storageConfig), [storageConfig]);
|
||||
|
||||
useEffect(() => {
|
||||
const storageCleanup = storage.registerListeners();
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import { $projectId } from 'app/store/nanostores/projectId';
|
||||
@@ -9,11 +7,57 @@ import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
const buildOSSServerBackedDriver = (): {
|
||||
const getUrl = (key?: string) => {
|
||||
const baseUrl = getBaseUrl();
|
||||
const query: Record<string, string> = {};
|
||||
if (key) {
|
||||
query['key'] = key;
|
||||
}
|
||||
const path = buildAppInfoUrl('client_state', query);
|
||||
const url = `${baseUrl}/${path}`;
|
||||
return url;
|
||||
};
|
||||
|
||||
const defaultGetItem = async (key: string): Promise<string | undefined> => {
|
||||
const url = getUrl(key);
|
||||
const res = await fetch(url, { method: 'GET' });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
return res.json();
|
||||
};
|
||||
|
||||
const defaultSetItem = async (key: string, value: string): Promise<string> => {
|
||||
const url = getUrl(key);
|
||||
const res = await fetch(url, { method: 'POST', body: value });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
return res.json();
|
||||
};
|
||||
|
||||
const defaultClear = async (): Promise<void> => {
|
||||
const url = getUrl();
|
||||
const res = await fetch(url, { method: 'DELETE' });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
};
|
||||
|
||||
export const buildStorage = (api?: {
|
||||
getItem: (key: string) => Promise<string | undefined>;
|
||||
setItem: (key: string, value: string) => Promise<string>;
|
||||
clear: () => Promise<void>;
|
||||
}): {
|
||||
reduxRememberDriver: ReduxRememberDriver;
|
||||
clearStorage: () => Promise<void>;
|
||||
registerListeners: () => () => void;
|
||||
} => {
|
||||
const _api = api ?? {
|
||||
getItem: defaultGetItem,
|
||||
setItem: defaultSetItem,
|
||||
clear: defaultClear,
|
||||
};
|
||||
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
|
||||
// it when a slice is being persisted and decrementing it when the persistence is done.
|
||||
let persistRefCount = 0;
|
||||
@@ -35,32 +79,15 @@ const buildOSSServerBackedDriver = (): {
|
||||
//
|
||||
// To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to
|
||||
// be persisted is the same as the last persisted value, we can skip the network request.
|
||||
const lastPersistedState = new Map<string, unknown>();
|
||||
|
||||
const getUrl = (key?: string) => {
|
||||
const baseUrl = getBaseUrl();
|
||||
const query: Record<string, string> = {};
|
||||
if (key) {
|
||||
query['key'] = key;
|
||||
}
|
||||
const path = buildAppInfoUrl('client_state', query);
|
||||
const url = `${baseUrl}/${path}`;
|
||||
return url;
|
||||
};
|
||||
const lastPersistedState = new Map<string, string | undefined>();
|
||||
|
||||
const reduxRememberDriver: ReduxRememberDriver = {
|
||||
getItem: async (key) => {
|
||||
try {
|
||||
const url = getUrl(key);
|
||||
const res = await fetch(url, { method: 'GET' });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
const text = await res.text();
|
||||
if (!lastPersistedState.get(key)) {
|
||||
lastPersistedState.set(key, text);
|
||||
}
|
||||
return JSON.parse(text);
|
||||
const value = await _api.getItem(key);
|
||||
lastPersistedState.set(key, value);
|
||||
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Getting state for ${key}`);
|
||||
return value;
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
@@ -73,20 +100,16 @@ const buildOSSServerBackedDriver = (): {
|
||||
try {
|
||||
persistRefCount++;
|
||||
if (lastPersistedState.get(key) === value) {
|
||||
log.trace(`Skipping persist for key "${key}" as value is unchanged.`);
|
||||
log.trace(
|
||||
{ key, last: lastPersistedState.get(key), next: value },
|
||||
`Skipping persist for ${key} as value is unchanged`
|
||||
);
|
||||
return value;
|
||||
}
|
||||
const url = getUrl(key);
|
||||
const headers = new Headers({
|
||||
'Content-Type': 'application/json',
|
||||
});
|
||||
const res = await fetch(url, { method: 'POST', headers, body: value });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
|
||||
lastPersistedState.set(key, value);
|
||||
return value;
|
||||
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`);
|
||||
const resultValue = await _api.setItem(key, value);
|
||||
lastPersistedState.set(key, resultValue);
|
||||
return resultValue;
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
@@ -107,11 +130,7 @@ const buildOSSServerBackedDriver = (): {
|
||||
const clearStorage = async () => {
|
||||
try {
|
||||
persistRefCount++;
|
||||
const url = getUrl();
|
||||
const res = await fetch(url, { method: 'DELETE' });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
await _api.clear();
|
||||
} catch {
|
||||
log.error('Failed to reset client state');
|
||||
} finally {
|
||||
@@ -139,105 +158,3 @@ const buildOSSServerBackedDriver = (): {
|
||||
|
||||
return { reduxRememberDriver, clearStorage, registerListeners };
|
||||
};
|
||||
|
||||
const buildCustomDriver = (api: {
|
||||
getItem: (key: string) => Promise<any>;
|
||||
setItem: (key: string, value: any) => Promise<any>;
|
||||
clear: () => Promise<void>;
|
||||
}): {
|
||||
reduxRememberDriver: ReduxRememberDriver;
|
||||
clearStorage: () => Promise<void>;
|
||||
registerListeners: () => () => void;
|
||||
} => {
|
||||
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
|
||||
let persistRefCount = 0;
|
||||
|
||||
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
|
||||
const lastPersistedState = new Map<string, unknown>();
|
||||
|
||||
const reduxRememberDriver: ReduxRememberDriver = {
|
||||
getItem: async (key) => {
|
||||
try {
|
||||
log.trace(`Getting client state for key "${key}"`);
|
||||
return await api.getItem(key);
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
}
|
||||
},
|
||||
setItem: async (key, value) => {
|
||||
try {
|
||||
persistRefCount++;
|
||||
|
||||
if (lastPersistedState.get(key) === value) {
|
||||
log.trace(`Skipping setting client state for key "${key}" as value is unchanged`);
|
||||
return value;
|
||||
}
|
||||
log.trace(`Setting client state for key "${key}", ${value}`);
|
||||
await api.setItem(key, value);
|
||||
lastPersistedState.set(key, value);
|
||||
return value;
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
value,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
} finally {
|
||||
persistRefCount--;
|
||||
if (persistRefCount < 0) {
|
||||
log.trace('Persist ref count is negative, resetting to 0');
|
||||
persistRefCount = 0;
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
const clearStorage = async () => {
|
||||
try {
|
||||
persistRefCount++;
|
||||
log.trace('Clearing client state');
|
||||
await api.clear();
|
||||
} catch {
|
||||
log.error('Failed to clear client state');
|
||||
} finally {
|
||||
persistRefCount--;
|
||||
lastPersistedState.clear();
|
||||
if (persistRefCount < 0) {
|
||||
log.trace('Persist ref count is negative, resetting to 0');
|
||||
persistRefCount = 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const registerListeners = () => {
|
||||
const onBeforeUnload = (e: BeforeUnloadEvent) => {
|
||||
if (persistRefCount > 0) {
|
||||
e.preventDefault();
|
||||
}
|
||||
};
|
||||
window.addEventListener('beforeunload', onBeforeUnload);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('beforeunload', onBeforeUnload);
|
||||
};
|
||||
};
|
||||
|
||||
return { reduxRememberDriver, clearStorage, registerListeners };
|
||||
};
|
||||
|
||||
export const buildStorageApi = (api?: {
|
||||
getItem: (key: string) => Promise<any>;
|
||||
setItem: (key: string, value: any) => Promise<any>;
|
||||
clear: () => Promise<void>;
|
||||
}) => {
|
||||
if (api) {
|
||||
return buildCustomDriver(api);
|
||||
} else {
|
||||
return buildOSSServerBackedDriver();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -127,9 +127,10 @@ const unserialize: UnserializeFunction = (data, key) => {
|
||||
let state;
|
||||
try {
|
||||
const initialState = getInitialState();
|
||||
const parsed = JSON.parse(data);
|
||||
|
||||
// strip out old keys
|
||||
const stripped = pick(deepClone(data), keys(initialState));
|
||||
const stripped = pick(deepClone(parsed), keys(initialState));
|
||||
/*
|
||||
* Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep,
|
||||
* but that merges arrays by index and partial objects by key. Using an identity function as the customizer results
|
||||
@@ -141,7 +142,7 @@ const unserialize: UnserializeFunction = (data, key) => {
|
||||
|
||||
log.debug(
|
||||
{
|
||||
persistedData: data as JsonObject,
|
||||
persistedData: parsed as JsonObject,
|
||||
rehydratedData: migrated as JsonObject,
|
||||
diff: diff(data, migrated) as JsonObject,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user