From da0efeaa7f485f67dc344dd6aea1f4dd05e50045 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 10 Aug 2023 15:20:37 +1000 Subject: [PATCH] fix(ui): fix canvas model switching There was no check at all to see if the canvas had a valid model already selected. The first model in the list was selected every time. Now, we check if its valid. If not, we go through the logic to try and pick the first valid model. If there are no valid models, or there was a problem listing models, the model selection is cleared. --- .../listeners/tabChanged.ts | 63 ++++++++++--------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts index 0569827859..6d3e599ae2 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts @@ -1,55 +1,58 @@ import { modelChanged } from 'features/parameters/store/generationSlice'; import { setActiveTab } from 'features/ui/store/uiSlice'; -import { forEach } from 'lodash-es'; import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; -import { - MainModelConfigEntity, - modelsApi, -} from 'services/api/endpoints/models'; +import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models'; import { startAppListening } from '..'; export const addTabChangedListener = () => { startAppListening({ actionCreator: setActiveTab, - effect: (action, { getState, dispatch }) => { + effect: async (action, { getState, dispatch }) => { const activeTabName = action.payload; if (activeTabName === 'unifiedCanvas') { - // grab the models from RTK Query cache - const { data } = modelsApi.endpoints.getMainModels.select( - NON_REFINER_BASE_MODELS - )(getState()); + const currentBaseModel = getState().generation.model?.base_model; - if (!data) { - // no models yet, so we can't do anything - dispatch(modelChanged(null)); + if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) { + // if we're already on a valid model, no change needed return; } - // need to filter out all the invalid canvas models (currently, this is just sdxl) - const validCanvasModels: MainModelConfigEntity[] = []; + try { + // just grab fresh models + const modelsRequest = dispatch( + modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS) + ); + const models = await modelsRequest.unwrap(); + // cancel this cache subscription + modelsRequest.unsubscribe(); - forEach(data.entities, (entity) => { - if (!entity) { + if (!models.ids.length) { + // no valid canvas models + dispatch(modelChanged(null)); return; } - if (['sd-1', 'sd-2'].includes(entity.base_model)) { - validCanvasModels.push(entity); + + // need to filter out all the invalid canvas models (currently sdxl & refiner) + const validCanvasModels = mainModelsAdapter + .getSelectors() + .selectAll(models) + .filter((model) => ['sd-1', 'sd-2'].includes(model.base_model)); + + const firstValidCanvasModel = validCanvasModels[0]; + + if (!firstValidCanvasModel) { + // no valid canvas models + dispatch(modelChanged(null)); + return; } - }); - // this could still be undefined even tho TS doesn't say so - const firstValidCanvasModel = validCanvasModels[0]; + const { base_model, model_name, model_type } = firstValidCanvasModel; - if (!firstValidCanvasModel) { - // uh oh, we have no models that are valid for canvas + dispatch(modelChanged({ base_model, model_name, model_type })); + } catch { + // network request failed, bail dispatch(modelChanged(null)); - return; } - - // only store the model name and base model in redux - const { base_model, model_name, model_type } = firstValidCanvasModel; - - dispatch(modelChanged({ base_model, model_name, model_type })); } }, });