Files
InvokeAI/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts

350 lines
11 KiB
TypeScript

import { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { ImageDTO } from 'services/api/types';
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
// CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
// ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
export type ControlModes =
| 'balanced'
| 'more_prompt'
| 'more_control'
| 'unbalanced';
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: '',
weight: 1,
beginStepPct: 0,
endStepPct: 1,
controlMode: 'balanced',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export type ControlNetConfig = {
controlNetId: string;
isEnabled: boolean;
model: string;
weight: number;
beginStepPct: number;
endStepPct: number;
controlMode: ControlModes;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlNetProcessorType;
processorNode: RequiredControlNetProcessorNode;
shouldAutoConfig: boolean;
};
export type ControlNetState = {
controlNets: Record<string, ControlNetConfig>;
isEnabled: boolean;
pendingControlImages: string[];
};
export const initialControlNetState: ControlNetState = {
controlNets: {},
isEnabled: false,
pendingControlImages: [],
};
export const controlNetSlice = createSlice({
name: 'controlNet',
initialState: initialControlNetState,
reducers: {
isControlNetEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled;
},
controlNetAdded: (
state,
action: PayloadAction<{
controlNetId: string;
controlNet?: ControlNetConfig;
}>
) => {
const { controlNetId, controlNet } = action.payload;
state.controlNets[controlNetId] = {
...(controlNet ?? initialControlNet),
controlNetId,
};
},
controlNetAddedFromImage: (
state,
action: PayloadAction<{ controlNetId: string; controlImage: string }>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId] = {
...initialControlNet,
controlNetId,
controlImage,
};
},
controlNetRemoved: (
state,
action: PayloadAction<{ controlNetId: string }>
) => {
const { controlNetId } = action.payload;
delete state.controlNets[controlNetId];
},
controlNetToggled: (
state,
action: PayloadAction<{ controlNetId: string }>
) => {
const { controlNetId } = action.payload;
state.controlNets[controlNetId].isEnabled =
!state.controlNets[controlNetId].isEnabled;
},
controlNetImageChanged: (
state,
action: PayloadAction<{
controlNetId: string;
controlImage: string | null;
}>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId].controlImage = controlImage;
state.controlNets[controlNetId].processedControlImage = null;
if (
controlImage !== null &&
state.controlNets[controlNetId].processorType !== 'none'
) {
state.pendingControlImages.push(controlNetId);
}
},
controlNetProcessedImageChanged: (
state,
action: PayloadAction<{
controlNetId: string;
processedControlImage: string | null;
}>
) => {
const { controlNetId, processedControlImage } = action.payload;
state.controlNets[controlNetId].processedControlImage =
processedControlImage;
state.pendingControlImages = state.pendingControlImages.filter(
(id) => id !== controlNetId
);
},
controlNetModelChanged: (
state,
action: PayloadAction<{
controlNetId: string;
model: string;
}>
) => {
const { controlNetId, model } = action.payload;
state.controlNets[controlNetId].model = model;
state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) {
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (model.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
].default as RequiredControlNetProcessorNode;
} else {
state.controlNets[controlNetId].processorType = 'none';
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS
.none.default as RequiredControlNetProcessorNode;
}
}
},
controlNetWeightChanged: (
state,
action: PayloadAction<{ controlNetId: string; weight: number }>
) => {
const { controlNetId, weight } = action.payload;
state.controlNets[controlNetId].weight = weight;
},
controlNetBeginStepPctChanged: (
state,
action: PayloadAction<{ controlNetId: string; beginStepPct: number }>
) => {
const { controlNetId, beginStepPct } = action.payload;
state.controlNets[controlNetId].beginStepPct = beginStepPct;
},
controlNetEndStepPctChanged: (
state,
action: PayloadAction<{ controlNetId: string; endStepPct: number }>
) => {
const { controlNetId, endStepPct } = action.payload;
state.controlNets[controlNetId].endStepPct = endStepPct;
},
controlNetControlModeChanged: (
state,
action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }>
) => {
const { controlNetId, controlMode } = action.payload;
state.controlNets[controlNetId].controlMode = controlMode;
},
controlNetProcessorParamsChanged: (
state,
action: PayloadAction<{
controlNetId: string;
changes: Omit<
Partial<RequiredControlNetProcessorNode>,
'id' | 'type' | 'is_intermediate'
>;
}>
) => {
const { controlNetId, changes } = action.payload;
const processorNode = state.controlNets[controlNetId].processorNode;
state.controlNets[controlNetId].processorNode = {
...processorNode,
...changes,
};
state.controlNets[controlNetId].shouldAutoConfig = false;
},
controlNetProcessorTypeChanged: (
state,
action: PayloadAction<{
controlNetId: string;
processorType: ControlNetProcessorType;
}>
) => {
const { controlNetId, processorType } = action.payload;
state.controlNets[controlNetId].processedControlImage = null;
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
].default as RequiredControlNetProcessorNode;
state.controlNets[controlNetId].shouldAutoConfig = false;
},
controlNetAutoConfigToggled: (
state,
action: PayloadAction<{
controlNetId: string;
}>
) => {
const { controlNetId } = action.payload;
const newShouldAutoConfig =
!state.controlNets[controlNetId].shouldAutoConfig;
if (newShouldAutoConfig) {
// manage the processor for the user
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (state.controlNets[controlNetId].model.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
].default as RequiredControlNetProcessorNode;
} else {
state.controlNets[controlNetId].processorType = 'none';
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS
.none.default as RequiredControlNetProcessorNode;
}
}
state.controlNets[controlNetId].shouldAutoConfig = newShouldAutoConfig;
},
controlNetReset: () => {
return { ...initialControlNetState };
},
},
extraReducers: (builder) => {
builder.addCase(controlNetImageProcessed, (state, action) => {
if (
state.controlNets[action.payload.controlNetId].controlImage !== null
) {
state.pendingControlImages.push(action.payload.controlNetId);
}
});
builder.addCase(imageDeleted.pending, (state, action) => {
// Preemptively remove the image from the gallery
const { image_name } = action.meta.arg;
forEach(state.controlNets, (c) => {
if (c.controlImage === image_name) {
c.controlImage = null;
c.processedControlImage = null;
}
if (c.processedControlImage === image_name) {
c.processedControlImage = null;
}
});
});
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
// forEach(state.controlNets, (c) => {
// if (c.controlImage?.image_name === image_name) {
// c.controlImage.image_url = image_url;
// c.controlImage.thumbnail_url = thumbnail_url;
// }
// if (c.processedControlImage?.image_name === image_name) {
// c.processedControlImage.image_url = image_url;
// c.processedControlImage.thumbnail_url = thumbnail_url;
// }
// });
// });
builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = [];
});
builder.addMatcher(isAnySessionRejected, (state, action) => {
state.pendingControlImages = [];
});
},
});
export const {
isControlNetEnabledToggled,
controlNetAdded,
controlNetAddedFromImage,
controlNetRemoved,
controlNetImageChanged,
controlNetProcessedImageChanged,
controlNetToggled,
controlNetModelChanged,
controlNetWeightChanged,
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
controlNetControlModeChanged,
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
controlNetReset,
controlNetAutoConfigToggled,
} = controlNetSlice.actions;
export default controlNetSlice.reducer;
export const controlNetSelector = (state: RootState) => state.controlNet;