Files
InvokeAI/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts
2024-04-30 08:10:59 -04:00

146 lines
5.0 KiB
TypeScript

import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import { isIPAdapterLayer, isRegionalGuidanceLayer } from 'features/controlLayers/store/controlLayersSlice';
import type { ImageField } from 'features/nodes/types/common';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { differenceWith, intersectionWith } from 'lodash-es';
import type {
CollectInvocation,
CoreMetadataInvocation,
IPAdapterInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import { assert } from 'tsafe';
import { IP_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
const getIPAdapters = (state: RootState) => {
// Start with the valid IP adapters
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage;
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
});
// Masked IP adapters are handled in the graph helper for regional control - skip them here
const maskedIPAdapterIds = state.controlLayers.present.layers
.filter(isRegionalGuidanceLayer)
.map((l) => l.ipAdapterIds)
.flat();
const nonMaskedIPAdapters = differenceWith(validIPAdapters, maskedIPAdapterIds, (a, b) => a.id === b);
// txt2img tab has special handling - it uses layers exclusively, while the other tabs use the older control adapters
// accordion. We need to filter the list of valid IP adapters according to the tab.
const activeTabName = activeTabNameSelector(state);
if (activeTabName === 'txt2img') {
// If we are on the t2i tab, we only want to add the IP adapters that are used in unmasked IP Adapter layers
// Collect all IP Adapter ids for enabled IP adapter layers
const layerIPAdapterIds = state.controlLayers.present.layers
.filter(isIPAdapterLayer)
.filter((l) => l.isEnabled)
.map((l) => l.ipAdapterId);
return intersectionWith(nonMaskedIPAdapters, layerIPAdapterIds, (a, b) => a.id === b);
} else {
// Else, we want to exclude the IP adapters that are used in IP Adapter layers
// Collect all IP Adapter ids for enabled IP adapter layers
const layerIPAdapterIds = state.controlLayers.present.layers.filter(isIPAdapterLayer).map((l) => l.ipAdapterId);
return differenceWith(nonMaskedIPAdapters, layerIPAdapterIds, (a, b) => a.id === b);
}
};
export const addIPAdapterToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const ipAdapters = getIPAdapters(state);
if (ipAdapters.length) {
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
const ipAdapterCollectNode: CollectInvocation = {
id: IP_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[IP_ADAPTER_COLLECT] = ipAdapterCollectNode;
graph.edges.push({
source: { node_id: IP_ADAPTER_COLLECT, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 'ip_adapter',
},
});
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
for (const ipAdapter of ipAdapters) {
if (!ipAdapter.model) {
return;
}
const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;
assert(controlImage, 'IP Adapter image is required');
const ipAdapterNode: IPAdapterInvocation = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
is_intermediate: true,
weight: weight,
method: method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: {
image_name: controlImage,
},
};
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapter));
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: {
node_id: ipAdapterCollectNode.id,
field: 'item',
},
});
}
upsertMetadata(graph, { ipAdapters: ipAdapterMetdata });
}
};
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, method, weight } = ipAdapter;
assert(model, 'IP Adapter model is required');
let image: ImageField | null = null;
if (controlImage) {
image = {
image_name: controlImage,
};
}
assert(image, 'IP Adapter image is required');
return {
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
weight,
method,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image,
};
};