possibly a working FLUX controlnet graph

This commit is contained in:
Mary Hipp
2024-10-09 15:41:43 -04:00
parent 8b1ef4b902
commit 63a2e17f6b
3 changed files with 29 additions and 4 deletions

View File

@@ -16,13 +16,13 @@ import {
controlLayerModelChanged,
controlLayerWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarBold, PiUploadBold } from 'react-icons/pi';
import type { ControlNetModelConfig, PostUploadAction, T2IAdapterModelConfig } from 'services/api/types';
import { selectIsFLUX } from '../../store/paramsSlice';
const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
const selectControlAdapter = useMemo(

View File

@@ -110,16 +110,21 @@ const addControlNetToGraph = (
const controlNet = g.addNode({
id: `control_net_${id}`,
type: 'controlnet',
type: model.base === 'flux' ? 'flux_controlnet' : 'controlnet',
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
control_mode: controlMode,
control_mode: model.base === 'flux' ? undefined : controlMode,
resize_mode: 'just_resize',
control_model: model,
control_weight: weight,
image: { image_name },
});
g.addEdge(controlNet, 'control', collector, 'item');
if (controlNet.type === 'flux_controlnet') {
g.addEdge(controlNet, 'controlnet', collector, 'item');
} else {
g.addEdge(controlNet, 'control', collector, 'item');
}
};
const addT2IAdapterToGraph = (

View File

@@ -19,6 +19,8 @@ import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addControlNets } from './addControlAdapters';
const log = logger('system');
export const buildFLUXGraph = async (
@@ -177,6 +179,24 @@ export const buildFLUXGraph = async (
);
}
const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
manager,
canvas.controlLayers.entities,
g,
canvas.bbox.rect,
controlNetCollector,
modelConfig.base
);
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', noise, 'controlnet');
} else {
g.deleteNode(controlNetCollector.id);
}
if (state.system.shouldUseNSFWChecker) {
canvasOutput = addNSFWChecker(g, canvasOutput);
}