mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Fix device handling for regional masks and apply the attention mask in the FLUX double stream block.
This commit is contained in:
@@ -172,6 +172,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
packed_height=packed_h,
|
||||
packed_width=packed_w,
|
||||
dtype=inference_dtype,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
neg_text_conditionings: list[FluxTextConditioning] | None = None
|
||||
if self.negative_text_conditioning is not None:
|
||||
@@ -181,6 +182,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
packed_height=packed_h,
|
||||
packed_width=packed_w,
|
||||
dtype=inference_dtype,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(pos_text_conditionings)
|
||||
neg_regional_prompting_extension = (
|
||||
@@ -353,6 +355,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
packed_height: int,
|
||||
packed_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> list[FluxTextConditioning]:
|
||||
"""Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
|
||||
# Normalize to a list of FluxConditioningFields.
|
||||
@@ -365,7 +368,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=dtype)
|
||||
flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
@@ -373,7 +376,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
mask: Optional[torch.Tensor] = None
|
||||
if cond_field.mask is not None:
|
||||
mask = context.tensors.load(cond_field.mask.tensor_name)
|
||||
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(mask, packed_height, packed_width, dtype)
|
||||
mask = mask.to(device=device)
|
||||
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
|
||||
mask, packed_height, packed_width, dtype, device
|
||||
)
|
||||
|
||||
text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
|
||||
|
||||
|
||||
@@ -74,7 +74,9 @@ class CustomDoubleStreamBlockProcessor:
|
||||
"""A custom implementation of DoubleStreamBlock.forward() with additional features:
|
||||
- IP-Adapter support
|
||||
"""
|
||||
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
|
||||
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
|
||||
block, img, txt, vec, pe, attn_mask=regional_prompting_extension.attn_mask
|
||||
)
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
for ip_adapter_extension in ip_adapter_extensions:
|
||||
|
||||
@@ -110,7 +110,7 @@ class RegionalPromptingExtension:
|
||||
|
||||
@staticmethod
|
||||
def preprocess_regional_prompt_mask(
|
||||
mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype
|
||||
mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
@@ -123,7 +123,7 @@ class RegionalPromptingExtension:
|
||||
"""
|
||||
|
||||
if mask is None:
|
||||
return torch.ones((1, 1, packed_height * packed_width), dtype=dtype)
|
||||
return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device)
|
||||
|
||||
mask = to_standard_float_mask(mask, out_dtype=dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user