Files
InvokeAI/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts
2024-05-01 21:48:47 -04:00

158 lines
4.6 KiB
TypeScript

import type { RootState } from 'app/store/store';
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType, ControlNetConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type {
CollectInvocation,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import { assert } from 'tsafe';
import { CONTROL_NET_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
/**
 * Adds every enabled, compatible ControlNet adapter to a linear (non-txt2img) graph.
 *
 * A single `collect` node is added and wired into `baseNodeId`'s `control` input. Each selected
 * adapter becomes a `controlnet` node whose `control` output feeds that collector. Metadata for
 * the added adapters is upserted into the graph under `controlnets`.
 *
 * An adapter is selected only if it is enabled, has a model whose base matches the main model's
 * base, and has a usable image: either a processed image with a processor selected, or a raw
 * control image.
 *
 * @param state - Root state; control adapters, the main model and the active tab are read from it.
 * @param graph - The graph to mutate in place.
 * @param baseNodeId - Id of the node whose `control` input receives the collected ControlNets.
 * @throws if called while the txt2img tab is active, or if an adapter is missing its image/model
 *   when its node or metadata is built.
 */
export const addControlNetToLinearGraph = async (
  state: RootState,
  graph: NonNullableGraph,
  baseNodeId: string
): Promise<void> => {
  const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
  const controlNets = selectValidControlNets(state.controlAdapters).filter(
    ({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
      const hasModel = Boolean(model);
      const doesBaseMatch = model?.base === state.generation.model?.base;
      const hasControlImage = (processedControlImage && processorType !== 'none') || controlImage;
      return isEnabled && hasModel && doesBaseMatch && hasControlImage;
    }
  );

  // The txt2img tab has special handling - its control adapters are set up in the Control Layers graph helper.
  const activeTabName = activeTabNameSelector(state);
  assert(activeTabName !== 'txt2img', 'Tried to use addControlNetToLinearGraph on txt2img tab');

  if (!controlNets.length) {
    return;
  }

  // Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect
  const controlNetCollectNode: CollectInvocation = {
    id: CONTROL_NET_COLLECT,
    type: 'collect',
    is_intermediate: true,
  };
  graph.nodes[CONTROL_NET_COLLECT] = controlNetCollectNode;
  graph.edges.push({
    source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
    destination: {
      node_id: baseNodeId,
      field: 'control',
    },
  });

  for (const controlNet of controlNets) {
    const {
      id,
      controlImage,
      processedControlImage,
      beginStepPct,
      endStepPct,
      controlMode,
      resizeMode,
      model,
      processorType,
      weight,
    } = controlNet;

    if (!model) {
      // The filter above already excludes model-less adapters; this guard exists to narrow `model`.
      // Skip (rather than `return`) so the remaining adapters and the metadata upsert are not dropped.
      continue;
    }

    const controlNetNode: ControlNetInvocation = {
      id: `control_net_${id}`,
      type: 'controlnet',
      is_intermediate: true,
      begin_step_percent: beginStepPct,
      end_step_percent: endStepPct,
      control_mode: controlMode,
      resize_mode: resizeMode,
      control_model: model,
      control_weight: weight,
      image: buildControlImage(controlImage, processedControlImage, processorType),
    };

    graph.nodes[controlNetNode.id] = controlNetNode;
    controlNetMetadata.push(buildControlNetMetadata(controlNet));
    graph.edges.push({
      source: { node_id: controlNetNode.id, field: 'control' },
      destination: {
        node_id: CONTROL_NET_COLLECT,
        field: 'item',
      },
    });
  }

  upsertMetadata(graph, { controlnets: controlNetMetadata });
};
/**
 * Chooses the image the `controlnet` node should consume.
 *
 * The processed image wins when it exists and a processor is actually selected
 * (`processorType !== 'none'`). Otherwise the raw control image is used, e.g. when the user
 * supplied an already-preprocessed image.
 *
 * @throws if neither candidate image is available.
 */
const buildControlImage = (
  controlImage: string | null,
  processedControlImage: string | null,
  processorType: ControlAdapterProcessorType
): ImageField => {
  const imageName = processedControlImage && processorType !== 'none' ? processedControlImage : controlImage;
  assert(imageName, 'ControlNet image is required');
  return { image_name: imageName };
};
/**
 * Builds the `controlnets` metadata entry recorded for a single ControlNet adapter.
 *
 * The raw control image is always recorded as `image`. The processed image is recorded as
 * `processed_image` only when it exists and a processor is selected (`processorType !== 'none'`);
 * otherwise `processed_image` is `null`.
 *
 * @throws if the adapter has no model (checked first) or no raw control image.
 */
const buildControlNetMetadata = (controlNet: ControlNetConfig): S['ControlNetMetadataField'] => {
  const { model, controlImage, processedControlImage, processorType } = controlNet;

  assert(model, 'ControlNet model is required');
  assert(controlImage, 'ControlNet image is required');

  return {
    control_model: model,
    control_weight: controlNet.weight,
    control_mode: controlNet.controlMode,
    begin_step_percent: controlNet.beginStepPct,
    end_step_percent: controlNet.endStepPct,
    resize_mode: controlNet.resizeMode,
    image: { image_name: controlImage },
    processed_image:
      processedControlImage && processorType !== 'none' ? { image_name: processedControlImage } : null,
  };
};