feat(ui): add mini/advanced controlnet ui

This commit is contained in:
psychedelicious
2023-06-03 15:05:49 +10:00
parent 69f0ba65f1
commit d6c08ba469
18 changed files with 430 additions and 548 deletions

View File

@@ -0,0 +1,100 @@
import { RootState } from 'app/store/store';
import { forEach, size } from 'lodash-es';
import { CollectInvocation, ControlNetInvocation } from 'services/api';
import { NonNullableGraph } from '../types/types';
const CONTROL_NET_COLLECT = 'control_net_collect';
export const addControlNetToLinearGraph = (
graph: NonNullableGraph,
baseNodeId: string,
state: RootState
): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
// Add ControlNet
if (isControlNetEnabled) {
if (size(controlNets) > 1) {
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
};
graph.nodes[controlNetIterateNode.id] = controlNetIterateNode;
graph.edges.push({
source: { node_id: controlNetIterateNode.id, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
}
forEach(controlNets, (controlNet, index) => {
const {
controlNetId,
isEnabled,
isPreprocessed: isControlImageProcessed,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
model,
processorNode,
weight,
} = controlNet;
if (!isEnabled) {
// Skip disabled ControlNets
return;
}
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage && !isControlImageProcessed) {
// We've already processed the image in the app, so we can just use the processed image
const { image_name, image_origin } = processedControlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else if (controlImage && isControlImageProcessed) {
// The control image is preprocessed
const { image_name, image_origin } = controlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode;
if (size(controlNets) > 1) {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
} else {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
}
});
}
};

View File

@@ -14,6 +14,7 @@ import {
import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
const moduleLog = log.child({ namespace: 'nodes' });
@@ -408,5 +409,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
});
}
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
return graph;
};

View File

@@ -1,8 +1,6 @@
import { RootState } from 'app/store/store';
import {
CollectInvocation,
CompelInvocation,
ControlNetInvocation,
Graph,
IterateInvocation,
LatentsToImageInvocation,
@@ -12,7 +10,7 @@ import {
TextToLatentsInvocation,
} from 'services/api';
import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
const POSITIVE_CONDITIONING = 'positive_conditioning';
const NEGATIVE_CONDITIONING = 'negative_conditioning';
@@ -22,7 +20,6 @@ const NOISE = 'noise';
const RANDOM_INT = 'rand_int';
const RANGE_OF_SIZE = 'range_of_size';
const ITERATE = 'iterate';
const CONTROL_NET_COLLECT = 'control_net_collect';
/**
* Builds the Text to Image tab graph.
@@ -42,8 +39,6 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
shouldRandomizeSeed,
} = state.generation;
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const graph: NonNullableGraph = {
nodes: {},
edges: [],
@@ -315,91 +310,7 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
});
}
// Add ControlNet
if (isControlNetEnabled) {
if (size(controlNets) > 1) {
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
};
graph.nodes[controlNetIterateNode.id] = controlNetIterateNode;
graph.edges.push({
source: { node_id: controlNetIterateNode.id, field: 'collection' },
destination: {
node_id: TEXT_TO_LATENTS,
field: 'control',
},
});
}
forEach(controlNets, (controlNet, index) => {
const {
controlNetId,
isEnabled,
isControlImageProcessed,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
model,
processorNode,
weight,
} = controlNet;
if (!isEnabled) {
// Skip disabled ControlNets
return;
}
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage && !isControlImageProcessed) {
// We've already processed the image in the app, so we can just use the processed image
const { image_name, image_origin } = processedControlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else if (controlImage && isControlImageProcessed) {
// The control image is preprocessed
const { image_name, image_origin } = controlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode;
if (size(controlNets) > 1) {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
} else {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: TEXT_TO_LATENTS,
field: 'control',
},
});
}
});
}
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
return graph;
};