Fix device handling for regional masks and apply the attention mask in the FLUX double stream block.

Ryan Dick
2024-11-25 16:02:03 +00:00
parent 2c23b8414c
commit 3741a6f5e0
3 changed files with 13 additions and 5 deletions


@@ -172,6 +172,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             packed_height=packed_h,
             packed_width=packed_w,
             dtype=inference_dtype,
+            device=TorchDevice.choose_torch_device(),
         )
         neg_text_conditionings: list[FluxTextConditioning] | None = None
         if self.negative_text_conditioning is not None:
@@ -181,6 +182,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             packed_height=packed_h,
             packed_width=packed_w,
             dtype=inference_dtype,
+            device=TorchDevice.choose_torch_device(),
         )
         pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(pos_text_conditionings)
         neg_regional_prompting_extension = (
@@ -353,6 +355,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         packed_height: int,
         packed_width: int,
         dtype: torch.dtype,
+        device: torch.device,
     ) -> list[FluxTextConditioning]:
         """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
         # Normalize to a list of FluxConditioningFields.
@@ -365,7 +368,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             assert len(cond_data.conditionings) == 1
             flux_conditioning = cond_data.conditionings[0]
             assert isinstance(flux_conditioning, FLUXConditioningInfo)
-            flux_conditioning = flux_conditioning.to(dtype=dtype)
+            flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
             t5_embeddings = flux_conditioning.t5_embeds
             clip_embeddings = flux_conditioning.clip_embeds
@@ -373,7 +376,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             mask: Optional[torch.Tensor] = None
             if cond_field.mask is not None:
                 mask = context.tensors.load(cond_field.mask.tensor_name)
-                mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(mask, packed_height, packed_width, dtype)
+                mask = mask.to(device=device)
+                mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
+                    mask, packed_height, packed_width, dtype, device
+                )
             text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
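
Note: the hunk above moves the loaded mask onto the inference device before preprocessing. As a rough illustration (not InvokeAI code; the tensor names and shapes here are made up), mixing a CPU-resident mask with accelerator-resident embeddings is exactly what fails without this:

import torch

# Hypothetical shapes, purely illustrative.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embeds = torch.randn(1, 512, 4096, device=device)  # conditioning lives on the accelerator
mask = torch.ones(1, 512)  # tensors loaded from storage start on the CPU

# embeds * mask.unsqueeze(-1)  # RuntimeError on CUDA: operands on different devices
mask = mask.to(device=device)  # the fix: move the mask over first
masked = embeds * mask.unsqueeze(-1)  # broadcasting now succeeds on a single device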


@@ -74,7 +74,9 @@ class CustomDoubleStreamBlockProcessor:
         """A custom implementation of DoubleStreamBlock.forward() with additional features:
         - IP-Adapter support
         """
-        img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
+        img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
+            block, img, txt, vec, pe, attn_mask=regional_prompting_extension.attn_mask
+        )
 
         # Apply IP-Adapter conditioning.
         for ip_adapter_extension in ip_adapter_extensions:
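
Note: before this hunk, the regional attention mask was computed but never passed into the double stream block, so regional prompts had no effect there. A hedged sketch of how such a mask typically gates attention via PyTorch's scaled dot-product attention (illustrative shapes, not the exact FLUX implementation):

import torch
import torch.nn.functional as F

# Illustrative sizes; the real head count and sequence lengths differ.
q = torch.randn(1, 24, 4608, 128)  # (batch, heads, tokens, head_dim)
k = torch.randn(1, 24, 4608, 128)
v = torch.randn(1, 24, 4608, 128)

# Boolean mask broadcastable to (batch, heads, query_len, key_len);
# True means "this query may attend to this key".
attn_mask = torch.ones(1, 1, 4608, 4608, dtype=torch.bool)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)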


@@ -110,7 +110,7 @@ class RegionalPromptingExtension:
     @staticmethod
     def preprocess_regional_prompt_mask(
-        mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype
+        mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device
     ) -> torch.Tensor:
         """Preprocess a regional prompt mask to match the target height and width.

         If mask is None, returns a mask of all ones with the target height and width.
@@ -123,7 +123,7 @@
         """
         if mask is None:
-            return torch.ones((1, 1, packed_height * packed_width), dtype=dtype)
+            return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device)

         mask = to_standard_float_mask(mask, out_dtype=dtype)
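
Note: passing device to torch.ones allocates the all-ones fallback mask directly on the target device, which both avoids a device mismatch downstream and saves a separate host-to-device copy. A minimal comparison (illustrative only):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
slow = torch.ones((1, 1, 64 * 64)).to(device)      # allocate on the CPU, then copy over
fast = torch.ones((1, 1, 64 * 64), device=device)  # allocate on the device directly
assert torch.equal(slow, fast)  # same values, but one transfer fewer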