Make negative_text_conditioning nullable on FLUX Denoise invocation.

This commit is contained in:
Ryan Dick
2024-10-18 20:31:27 +00:00
parent 371742d8f9
commit 6df4ee5fc8
2 changed files with 23 additions and 12 deletions

View File

@@ -82,8 +82,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
negative_text_conditioning: FluxConditioningField | None = InputField(
default=None,
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
)
# TODO(ryand): Add cfg_scale range validation.
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
@@ -136,9 +138,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
context, self.positive_text_conditioning.conditioning_name, inference_dtype
)
neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
context, self.negative_text_conditioning.conditioning_name, inference_dtype
)
neg_t5_embeddings: torch.Tensor | None = None
neg_clip_embeddings: torch.Tensor | None = None
if self.negative_text_conditioning is not None:
neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
context, self.negative_text_conditioning.conditioning_name, inference_dtype
)
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
@@ -203,10 +208,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
pos_txt_ids = torch.zeros(
pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
neg_txt_ids = torch.zeros(
neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
neg_txt_ids: torch.Tensor | None = None
if neg_t5_embeddings is not None:
neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
neg_txt_ids = torch.zeros(
neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None

View File

@@ -22,9 +22,9 @@ def denoise(
txt_ids: torch.Tensor,
vec: torch.Tensor,
# negative text conditioning
neg_txt: torch.Tensor,
neg_txt_ids: torch.Tensor,
neg_vec: torch.Tensor,
neg_txt: torch.Tensor | None,
neg_txt_ids: torch.Tensor | None,
neg_vec: torch.Tensor | None,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
@@ -90,6 +90,10 @@ def denoise(
if not math.isclose(step_cfg_scale, 1.0):
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance on
# systems with sufficient VRAM.
if neg_txt is None or neg_txt_ids is None or neg_vec is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
neg_pred = model(
img=img,
img_ids=img_ids,