feat(ui): initial implementation of model loading

- Update model listing code to use `rtk-query`
- Update all graph generation to use new `pipeline_model_loader` node
This commit is contained in:
psychedelicious
2023-06-22 17:48:57 +10:00
parent 2a178f5a25
commit 339e7ce213
26 changed files with 281 additions and 386 deletions

View File

@@ -24,6 +24,7 @@ import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { useListModelsQuery } from 'services/apiSlice';
const DEFAULT_CONFIG = {};
@@ -46,6 +47,18 @@ const App = ({
const isApplicationReady = useIsApplicationReady();
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const { data: controlnetModels } = useListModelsQuery({
model_type: 'controlnet',
});
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
const { data: embeddingModels } = useListModelsQuery({
model_type: 'embedding',
});
const [loadingOverridden, setLoadingOverridden] = useState(false);
const dispatch = useAppDispatch();

View File

@@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import { omit } from 'lodash-es';
@@ -18,8 +17,6 @@ const serializationDenylist: {
gallery: galleryPersistDenylist,
generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist,
sd1models: modelsPersistDenylist,
sd2models: modelsPersistDenylist,
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist,

View File

@@ -7,8 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice';
import { sd1InitialPipelineModelsState } from 'features/system/store/models/sd1PipelineModelSlice';
import { sd2InitialPipelineModelsState } from 'features/system/store/models/sd2PipelineModelSlice';
import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice';
@@ -22,8 +20,6 @@ const initialStates: {
gallery: initialGalleryState,
generation: initialGenerationState,
lightbox: initialLightboxState,
sd1PipelineModels: sd1InitialPipelineModelsState,
sd2PipelineModels: sd2InitialPipelineModelsState,
nodes: initialNodesState,
postprocessing: initialPostprocessingState,
system: initialSystemState,

View File

@@ -1,7 +1,6 @@
import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import { startAppListening } from '../..';
@@ -15,8 +14,7 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected');
const { sd1pipelinemodels, sd2pipelinemodels, nodes, config, images } =
getState();
const { nodes, config, images } = getState();
const { disabledTabs } = config;
@@ -29,14 +27,6 @@ export const addSocketConnectedEventListener = () => {
);
}
if (!sd1pipelinemodels.ids.length) {
dispatch(receivedModels({ baseModel: 'sd-1', modelType: 'pipeline' }));
}
if (!sd2pipelinemodels.ids.length) {
dispatch(receivedModels({ baseModel: 'sd-2', modelType: 'pipeline' }));
}
if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema());
}

View File

@@ -28,11 +28,6 @@ import { listenerMiddleware } from './middleware/listenerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
// Model Reducers
import sd1PipelineModelReducer from 'features/system/store/models/sd1PipelineModelSlice';
import sd2PipelineModelReducer from 'features/system/store/models/sd2PipelineModelSlice';
import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize';
@@ -43,8 +38,6 @@ const allReducers = {
gallery: galleryReducer,
generation: generationReducer,
lightbox: lightboxReducer,
sd1pipelinemodels: sd1PipelineModelReducer,
sd2pipelinemodels: sd2PipelineModelReducer,
nodes: nodesReducer,
postprocessing: postprocessingReducer,
system: systemReducer,
@@ -54,8 +47,8 @@ const allReducers = {
images: imagesReducer,
controlNet: controlNetReducer,
boards: boardsReducer,
[api.reducerPath]: api.reducer,
// session: sessionReducer,
[api.reducerPath]: api.reducer,
};
const rootReducer = combineReducers(allReducers);