From f002bca2fa057aa9dcb997fc3520a3b31a12578a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 17 Jun 2024 10:07:10 +1000 Subject: [PATCH] feat(ui): handle new `model_install_download_started` event When a model install is initiated from outside the client, we now trigger the model manager tab's model install list to update. - Handle new `model_install_download_started` event - Handle `model_install_download_complete` event (this event is not new but was never handled) - Update optimistic updates/cache invalidation logic to efficiently update the model install list --- .../listeners/socketio/socketModelInstall.ts | 184 +++++++++++++----- 1 file changed, 136 insertions(+), 48 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts index 113d2cbd66..22ad87fbe9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts @@ -5,6 +5,8 @@ import { socketModelInstallCancelled, socketModelInstallComplete, socketModelInstallDownloadProgress, + socketModelInstallDownloadsComplete, + socketModelInstallDownloadStarted, socketModelInstallError, socketModelInstallStarted, } from 'services/events/actions'; @@ -14,9 +16,12 @@ import { * which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully * downloaded and is being "physically" installed. * + * Note: the download events are only fired for remote model installs, not local. + * * Here's the expected flow: - * - Model manager does some prep - * - `model_install_download_progress` fired when the download starts and continually until the download is complete + * - API receives install request, model manager preps the install + * - `model_install_download_started` fired when the download starts + * - `model_install_download_progress` fired continually until the download is complete * - `model_install_download_complete` fired when the download is complete * - `model_install_started` fired when the "physical" installation starts * - `model_install_complete` fired when the installation is complete @@ -24,47 +29,98 @@ import { * - `model_install_error` fired if the installation has an error */ +const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select(); + export const addModelInstallEventListener = (startAppListening: AppStartListening) => { + startAppListening({ + actionCreator: socketModelInstallDownloadStarted, + effect: async (action, { dispatch, getState }) => { + const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'downloading'; + } + return draft; + }) + ); + } + }, + }); + startAppListening({ actionCreator: socketModelInstallStarted, - effect: async (action, { dispatch }) => { - dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + effect: async (action, { dispatch, getState }) => { + const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'running'; + } + return draft; + }) + ); + } }, }); startAppListening({ actionCreator: socketModelInstallDownloadProgress, - effect: async (action, { dispatch }) => { + effect: async (action, { dispatch, getState }) => { const { bytes, total_bytes, id } = action.payload.data; + const { data } = selectModelInstalls(getState()); - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - modelImport.bytes = bytes; - modelImport.total_bytes = total_bytes; - modelImport.status = 'downloading'; - } - return draft; - }) - ); + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.bytes = bytes; + modelImport.total_bytes = total_bytes; + modelImport.status = 'downloading'; + } + return draft; + }) + ); + } }, }); startAppListening({ actionCreator: socketModelInstallComplete, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const { id } = action.payload.data; - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - modelImport.status = 'completed'; - } - return draft; - }) - ); + const { data } = selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'completed'; + } + return draft; + }) + ); + } + dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }])); dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }])); }, @@ -72,37 +128,69 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin startAppListening({ actionCreator: socketModelInstallError, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const { id, error, error_type } = action.payload.data; + const { data } = selectModelInstalls(getState()); - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - modelImport.status = 'error'; - modelImport.error_reason = error_type; - modelImport.error = error; - } - return draft; - }) - ); + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'error'; + modelImport.error_reason = error_type; + modelImport.error = error; + } + return draft; + }) + ); + } }, }); startAppListening({ actionCreator: socketModelInstallCancelled, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - modelImport.status = 'cancelled'; - } - return draft; - }) - ); + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'cancelled'; + } + return draft; + }) + ); + } + }, + }); + + startAppListening({ + actionCreator: socketModelInstallDownloadsComplete, + effect: (action, { dispatch, getState }) => { + const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'downloads_done'; + } + return draft; + }) + ); + } }, }); };