mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add FLUX Control LoRA weight param (#7452)
## Summary Add the ability to control the weight of a FLUX Control LoRA. ## Example Original image: <div style="display: flex; gap: 10px;"> <img src="https://github.com/user-attachments/assets/4a2d9f4a-b58b-4df6-af90-67b018763a38" alt="Image 1" width="300"/> </div> Prompt: `a scarecrow playing tennis` Weights: 0.4, 0.6, 0.8, 1.0 <div style="display: flex; gap: 10px;"> <img src="https://github.com/user-attachments/assets/62b83fd6-46ce-460a-8d51-9c2cda9b05c9" alt="Image 1" width="300"/> <img src="https://github.com/user-attachments/assets/75442207-1538-46bc-9d6b-08ac5c235c93" alt="Image 2" width="300"/> </div> <div style="display: flex; gap: 10px;"> <img src="https://github.com/user-attachments/assets/4a9dc9ea-9757-4965-837e-197fc9243007" alt="Image 1" width="300"/> <img src="https://github.com/user-attachments/assets/846f6918-ca82-4482-8c19-19172752fa8c" alt="Image 2" width="300"/> </div> ## QA Instructions - [x] weight control changes strength of control image - [x] Test that results match across both quantized and non-quantized. ## Merge Plan **_Do not merge this PR yet._** 1. Merge #7450 2. Merge #7446 3. Change target branch to main 4. Merge this branch. ## Checklist - [ ] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
This commit is contained in:
@@ -24,7 +24,7 @@ class FluxControlLoRALoaderOutput(BaseInvocationOutput):
|
||||
title="Flux Control LoRA",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxControlLoRALoaderInvocation(BaseInvocation):
|
||||
@@ -34,6 +34,7 @@ class FluxControlLoRALoaderInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
|
||||
)
|
||||
image: ImageField = InputField(description="The image to encode.")
|
||||
weight: float = InputField(description="The weight of the LoRA.", default=1.0)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
|
||||
if not context.models.exists(self.lora.key):
|
||||
@@ -43,6 +44,6 @@ class FluxControlLoRALoaderInvocation(BaseInvocation):
|
||||
control_lora=ControlLoRAField(
|
||||
lora=self.lora,
|
||||
img=self.image,
|
||||
weight=1,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -15,8 +15,10 @@ class SetParameterLayer(BaseLayerPatch):
|
||||
self.param_name = param_name
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
|
||||
# Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX
|
||||
# Control LoRA implementation.
|
||||
diff = self.weight - orig_module.get_parameter(self.param_name)
|
||||
return {self.param_name: diff * weight}
|
||||
return {self.param_name: diff}
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
@@ -162,7 +162,7 @@ export const ControlLayerControlAdapter = memo(() => {
|
||||
/>
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
</Flex>
|
||||
{controlAdapter.type !== 'control_lora' && <Weight weight={controlAdapter.weight} onChange={onChangeWeight} />}
|
||||
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
|
||||
{controlAdapter.type !== 'control_lora' && (
|
||||
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
|
||||
)}
|
||||
|
||||
@@ -74,6 +74,7 @@ import {
|
||||
getReferenceImageState,
|
||||
getRegionalGuidanceState,
|
||||
imageDTOToImageWithDims,
|
||||
initialControlLoRA,
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
initialT2IAdapter,
|
||||
@@ -462,38 +463,64 @@ export const canvasSlice = createSlice({
|
||||
}
|
||||
layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig);
|
||||
|
||||
// When converting between control layer types, we may need to add or remove properties. For example, ControlNet
|
||||
// has a control mode, while T2I Adapter does not - otherwise they are the same.
|
||||
|
||||
switch (layer.controlAdapter.model.type) {
|
||||
// Converting to T2I adapter from...
|
||||
case 't2i_adapter': {
|
||||
if (layer.controlAdapter.type === 'controlnet') {
|
||||
// T2I Adapters have all the ControlNet properties, minus control mode - strip it
|
||||
const { controlMode: _, ...rest } = layer.controlAdapter;
|
||||
const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' };
|
||||
const t2iAdapterConfig: T2IAdapterConfig = { ...initialT2IAdapter, ...rest, type: 't2i_adapter' };
|
||||
layer.controlAdapter = t2iAdapterConfig;
|
||||
} else if (layer.controlAdapter.type === 'control_lora') {
|
||||
const t2iAdapterConfig: T2IAdapterConfig = { ...layer.controlAdapter, ...initialT2IAdapter };
|
||||
// Control LoRAs have only model and weight
|
||||
const t2iAdapterConfig: T2IAdapterConfig = {
|
||||
...initialT2IAdapter,
|
||||
...layer.controlAdapter,
|
||||
type: 't2i_adapter',
|
||||
};
|
||||
layer.controlAdapter = t2iAdapterConfig;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Converting to ControlNet from...
|
||||
case 'controlnet': {
|
||||
if (layer.controlAdapter.type === 't2i_adapter') {
|
||||
// ControlNets have all the T2I Adapter properties, plus control mode
|
||||
const controlNetConfig: ControlNetConfig = {
|
||||
...initialControlNet,
|
||||
...layer.controlAdapter,
|
||||
type: 'controlnet',
|
||||
controlMode: initialControlNet.controlMode,
|
||||
};
|
||||
layer.controlAdapter = controlNetConfig;
|
||||
} else if (layer.controlAdapter.type === 'control_lora') {
|
||||
const controlNetConfig: ControlNetConfig = { ...layer.controlAdapter, ...initialControlNet };
|
||||
// ControlNets have all the Control LoRA properties, plus control mode and begin/end step pct
|
||||
const controlNetConfig: ControlNetConfig = {
|
||||
...initialControlNet,
|
||||
...layer.controlAdapter,
|
||||
type: 'controlnet',
|
||||
};
|
||||
layer.controlAdapter = controlNetConfig;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Converting to ControlLoRA from...
|
||||
case 'control_lora': {
|
||||
const controlLoraConfig: ControlLoRAConfig = { ...layer.controlAdapter, type: 'control_lora' };
|
||||
layer.controlAdapter = controlLoraConfig;
|
||||
|
||||
if (layer.controlAdapter.type === 'controlnet') {
|
||||
// We only need the model and weight for Control LoRA
|
||||
const { model, weight } = layer.controlAdapter;
|
||||
const controlNetConfig: ControlLoRAConfig = { ...initialControlLoRA, model, weight };
|
||||
layer.controlAdapter = controlNetConfig;
|
||||
} else if (layer.controlAdapter.type === 't2i_adapter') {
|
||||
// We only need the model and weight for Control LoRA
|
||||
const { model, weight } = layer.controlAdapter;
|
||||
const t2iAdapterConfig: ControlLoRAConfig = { ...initialControlLoRA, model, weight };
|
||||
layer.controlAdapter = t2iAdapterConfig;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -518,7 +545,7 @@ export const canvasSlice = createSlice({
|
||||
) => {
|
||||
const { entityIdentifier, weight } = action.payload;
|
||||
const layer = selectEntity(state, entityIdentifier);
|
||||
if (!layer || !layer.controlAdapter || layer.controlAdapter.type === 'control_lora') {
|
||||
if (!layer || !layer.controlAdapter) {
|
||||
return;
|
||||
}
|
||||
layer.controlAdapter.weight = weight;
|
||||
|
||||
@@ -298,6 +298,7 @@ export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
|
||||
|
||||
const zControlLoRAConfig = z.object({
|
||||
type: z.literal('control_lora'),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
});
|
||||
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
|
||||
|
||||
@@ -7,6 +7,7 @@ import type {
|
||||
CanvasRasterLayerState,
|
||||
CanvasReferenceImageState,
|
||||
CanvasRegionalGuidanceState,
|
||||
ControlLoRAConfig,
|
||||
ControlNetConfig,
|
||||
ImageWithDims,
|
||||
IPAdapterConfig,
|
||||
@@ -82,6 +83,11 @@ export const initialControlNet: ControlNetConfig = {
|
||||
beginEndStepPct: [0, 0.75],
|
||||
controlMode: 'balanced',
|
||||
};
|
||||
export const initialControlLoRA: ControlLoRAConfig = {
|
||||
type: 'control_lora',
|
||||
model: null,
|
||||
weight: 0.75,
|
||||
};
|
||||
|
||||
export const getReferenceImageState = (
|
||||
id: string,
|
||||
|
||||
@@ -207,7 +207,7 @@ const addControlLoRAToGraph = (
|
||||
) => {
|
||||
const { id, controlAdapter } = layer;
|
||||
assert(controlAdapter.type === 'control_lora');
|
||||
const { model } = controlAdapter;
|
||||
const { model, weight } = controlAdapter;
|
||||
assert(model !== null);
|
||||
const { image_name } = imageDTO;
|
||||
|
||||
@@ -216,6 +216,7 @@ const addControlLoRAToGraph = (
|
||||
type: 'flux_control_lora_loader',
|
||||
lora: model,
|
||||
image: { image_name },
|
||||
weight: weight,
|
||||
});
|
||||
|
||||
g.addEdge(controlLoRA, 'control_lora', denoise, 'control_lora');
|
||||
|
||||
@@ -6708,6 +6708,12 @@ export type components = {
|
||||
* @default null
|
||||
*/
|
||||
image?: components["schemas"]["ImageField"];
|
||||
/**
|
||||
* Weight
|
||||
* @description The weight of the LoRA.
|
||||
* @default 1
|
||||
*/
|
||||
weight?: number;
|
||||
/**
|
||||
* type
|
||||
* @default flux_control_lora_loader
|
||||
@@ -6722,11 +6728,11 @@ export type components = {
|
||||
*/
|
||||
FluxControlLoRALoaderOutput: {
|
||||
/**
|
||||
* Flux Control Lora
|
||||
* Flux Control LoRA
|
||||
* @description Control LoRAs to apply on model loading
|
||||
* @default null
|
||||
*/
|
||||
control_lora: components["schemas"]["ControlLoRAField"] | null;
|
||||
control_lora: components["schemas"]["ControlLoRAField"];
|
||||
/**
|
||||
* type
|
||||
* @default flux_control_lora_loader_output
|
||||
@@ -6926,7 +6932,7 @@ export type components = {
|
||||
*/
|
||||
transformer?: components["schemas"]["TransformerField"];
|
||||
/**
|
||||
* Control Lora
|
||||
* Control LoRA
|
||||
* @description Control LoRA model to load
|
||||
* @default null
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user