Expanded styles & updated UI

This commit is contained in:
Kent Keirsey
2025-05-18 12:42:14 -04:00
committed by psychedelicious
parent d709040f4b
commit b02ea1a898
6 changed files with 65 additions and 10 deletions

View File

@@ -95,7 +95,7 @@ class IPAdapterInvocation(BaseInvocation):
weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
method: Literal["full", "style", "composition"] = InputField(
method: Literal["full", "style", "composition", "style_strong", "style_precise"] = InputField(
default="full", description="The method to apply the IP-Adapter"
)
begin_step_percent: float = InputField(
@@ -133,12 +133,13 @@ class IPAdapterInvocation(BaseInvocation):
image_encoder_model_name = image_encoder_starter_model.name
image_encoder_model = self.get_clip_image_encoder(context, image_encoder_model_id, image_encoder_model_name)
negative_blocks: List[str] = []
if self.method == "style":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.1"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"]
target_blocks = ["up_blocks.0.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "composition":
@@ -148,6 +149,38 @@ class IPAdapterInvocation(BaseInvocation):
target_blocks = ["down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "style_precise":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.1","down_blocks.2","mid_block"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["up_blocks.0.attentions.1","down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "style_strong":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.0", "up_blocks.1", "up_blocks.2", "down_blocks.0", "down_blocks.1"]
elif ip_adapter_info.base == "sdxl":
target_blocks = [
"up_blocks.0.attentions.1",
"up_blocks.1.attentions.1",
"up_blocks.2.attentions.1",
"up_blocks.0.attentions.2",
"up_blocks.1.attentions.2",
"up_blocks.2.attentions.2",
"up_blocks.0.attentions.0",
"up_blocks.1.attentions.0",
"up_blocks.2.attentions.0",
"down_blocks.0.attentions.0",
"down_blocks.0.attentions.1",
"down_blocks.0.attentions.2",
"down_blocks.1.attentions.0",
"down_blocks.1.attentions.1",
"down_blocks.1.attentions.2",
"down_blocks.2.attentions.0",
"down_blocks.2.attentions.2",
]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "full":
target_blocks = ["block"]
else:

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import math
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
@@ -104,12 +104,24 @@ class IPAdapterConditioningInfo:
@dataclass
class IPAdapterData:
"""Data class for IP-Adapter configuration.
Attributes:
ip_adapter_model: The IP-Adapter model to use.
ip_adapter_conditioning: The IP-Adapter conditioning data.
mask: The mask to apply to the IP-Adapter conditioning.
target_blocks: List of target attention block names to apply IP-Adapter to.
negative_blocks: List of target attention block names that should use negative attention.
weight: The weight to apply to the IP-Adapter conditioning.
begin_step_percent: The percentage of steps at which to start applying the IP-Adapter.
end_step_percent: The percentage of steps at which to stop applying the IP-Adapter.
method: The method to use for applying the IP-Adapter ('full', 'style', 'composition', 'style_strong', 'style_precise').
"""
ip_adapter_model: IPAdapter
ip_adapter_conditioning: IPAdapterConditioningInfo
mask: torch.Tensor
target_blocks: List[str]
# weight: either a single weight applied to all steps, or a list of per-step weights.
negative_blocks: List[str] = field(default_factory=list)
weight: Union[float, List[float]] = 1.0
begin_step_percent: float = 0.0
end_step_percent: float = 1.0

View File

@@ -12,8 +12,8 @@ from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
class UNetIPAdapterData(TypedDict):
ip_adapter: IPAdapter
target_blocks: List[str]
method: str
target_blocks: List[str] # Blocks where IP-Adapter should be applied
method: str # Style or other method type
class UNetAttentionPatcher:
@@ -44,7 +44,7 @@ class UNetAttentionPatcher:
for block in ip_adapter["target_blocks"]:
if block in name:
skip = False
negative = (ip_adapter["method"] == 'style' and block == "down_blocks.2.attentions.1")
negative = (ip_adapter["method"] == 'style_precise' and block == "down_blocks.2.attentions.1")
break
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
ip_adapter_weights=ip_adapter_weights, skip=skip, negative=negative

View File

@@ -35,6 +35,16 @@ export const IPAdapterMethod = memo(({ method, onChange }: Props) => {
value: 'composition',
description: shouldShowModelDescriptions ? t('controlLayers.ipAdapterMethod.compositionDesc') : undefined,
},
{
label: t('controlLayers.ipAdapterMethod.styleStrong'),
value: 'style_strong',
description: shouldShowModelDescriptions ? t('controlLayers.ipAdapterMethod.styleStrongDesc') : undefined,
},
{
label: t('controlLayers.ipAdapterMethod.stylePrecise'),
value: 'style_precise',
description: shouldShowModelDescriptions ? t('controlLayers.ipAdapterMethod.stylePreciseDesc') : undefined,
},
],
[t, shouldShowModelDescriptions]
);

View File

@@ -50,7 +50,7 @@ const zCLIPVisionModelV2 = z.enum(['ViT-H', 'ViT-G', 'ViT-L']);
export type CLIPVisionModelV2 = z.infer<typeof zCLIPVisionModelV2>;
export const isCLIPVisionModelV2 = (v: unknown): v is CLIPVisionModelV2 => zCLIPVisionModelV2.safeParse(v).success;
const zIPMethodV2 = z.enum(['full', 'style', 'composition']);
const zIPMethodV2 = z.enum(['full', 'style', 'composition', 'style_strong', 'style_precise']);
export type IPMethodV2 = z.infer<typeof zIPMethodV2>;
export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safeParse(v).success;

View File

@@ -147,7 +147,7 @@ export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zModelIdentifierField,
weight: z.number(),
method: z.enum(['full', 'style', 'composition']),
method: z.enum(['full', 'style', 'composition', 'style_strong', 'style_precise']),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
});