mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-07 03:04:56 -05:00
possibly a working FLUX controlnet graph
This commit is contained in:
@@ -16,13 +16,13 @@ import {
|
||||
controlLayerModelChanged,
|
||||
controlLayerWeightChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold, PiShootingStarBold, PiUploadBold } from 'react-icons/pi';
|
||||
import type { ControlNetModelConfig, PostUploadAction, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { selectIsFLUX } from '../../store/paramsSlice';
|
||||
|
||||
const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
|
||||
const selectControlAdapter = useMemo(
|
||||
|
||||
@@ -110,16 +110,21 @@ const addControlNetToGraph = (
|
||||
|
||||
const controlNet = g.addNode({
|
||||
id: `control_net_${id}`,
|
||||
type: 'controlnet',
|
||||
type: model.base === 'flux' ? 'flux_controlnet' : 'controlnet',
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
control_mode: controlMode,
|
||||
control_mode: model.base === 'flux' ? undefined : controlMode,
|
||||
resize_mode: 'just_resize',
|
||||
control_model: model,
|
||||
control_weight: weight,
|
||||
image: { image_name },
|
||||
});
|
||||
g.addEdge(controlNet, 'control', collector, 'item');
|
||||
|
||||
if (controlNet.type === 'flux_controlnet') {
|
||||
g.addEdge(controlNet, 'controlnet', collector, 'item');
|
||||
} else {
|
||||
g.addEdge(controlNet, 'control', collector, 'item');
|
||||
}
|
||||
};
|
||||
|
||||
const addT2IAdapterToGraph = (
|
||||
|
||||
@@ -19,6 +19,8 @@ import type { Invocation } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { addControlNets } from './addControlAdapters';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildFLUXGraph = async (
|
||||
@@ -177,6 +179,24 @@ export const buildFLUXGraph = async (
|
||||
);
|
||||
}
|
||||
|
||||
const controlNetCollector = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('control_net_collector'),
|
||||
});
|
||||
const controlNetResult = await addControlNets(
|
||||
manager,
|
||||
canvas.controlLayers.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
controlNetCollector,
|
||||
modelConfig.base
|
||||
);
|
||||
if (controlNetResult.addedControlNets > 0) {
|
||||
g.addEdge(controlNetCollector, 'collection', noise, 'controlnet');
|
||||
} else {
|
||||
g.deleteNode(controlNetCollector.id);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
canvasOutput = addNSFWChecker(g, canvasOutput);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user