fixed node issue

Authored by Ubuntu on 2025-07-17 10:31:27 +00:00, committed by Kent Keirsey
parent 8523ea88f2
commit 282df322d5
4 changed files with 23 additions and 18 deletions

View File

@@ -1,13 +1,13 @@
 from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
 from pydantic import BaseModel, Field
-from invokeai.invocation_api import ImageOutput
+from invokeai.invocation_api import ImageOutput, Classification
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
+from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType, WithBoard, WithMetadata
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 import numpy as np
@@ -26,9 +26,9 @@ class BriaControlNetField(BaseModel):
     mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
     conditioning_scale: float = Field(description="The weight given to the ControlNet")
 
-@invocation_output("flux_controlnet_output")
+@invocation_output("bria_controlnet_output")
 class BriaControlNetOutput(BaseInvocationOutput):
-    """FLUX ControlNet info"""
+    """Bria ControlNet info"""
 
     control: BriaControlNetField = OutputField(description=FieldDescriptions.control)
     preprocessed_images: ImageField = OutputField(description="The preprocessed control image")
@@ -40,8 +40,9 @@ class BriaControlNetOutput(BaseInvocationOutput):
     tags=["controlnet", "bria"],
     category="controlnet",
     version="1.0.0",
+    classification=Classification.Prototype,
 )
-class BriaControlNetInvocation(BaseInvocation):
+class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Collect Bria ControlNet info to pass to denoiser node."""
 
     control_image: ImageField = InputField(description="The control image")
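
For context, the mixin and classification additions follow InvokeAI's usual node pattern. Below is a minimal sketch of a node using them, assuming InvokeAI's public invocation API; the node id, title, and fields are illustrative, not part of this commit:

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.invocation_api import Classification, ImageOutput


@invocation(
    "example_prototype_node",  # hypothetical node id, not from this commit
    title="Example Prototype Node",
    tags=["example"],
    category="examples",
    version="1.0.0",
    classification=Classification.Prototype,  # flags the node as experimental in the workflow editor
)
class ExamplePrototypeInvocation(BaseInvocation, WithMetadata, WithBoard):
    """WithMetadata/WithBoard let saved images carry workflow metadata and a target board."""

    image: ImageField = InputField(description="The image to pass through")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Pass-through: context.images.save picks up the board/metadata handling
        # contributed by the WithBoard/WithMetadata mixins.
        pil_image = context.images.get_pil(self.image.image_name)
        image_dto = context.images.save(image=pil_image)
        return ImageOutput.build(image_dto)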

View File

@@ -3,7 +3,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
 from invokeai.backend.bria.controlnet_utils import prepare_control_images
 from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
-from invokeai.nodes.bria_nodes.bria_controlnet import BriaControlNetField
+from invokeai.app.invocations.bria_controlnet import BriaControlNetField
 import torch
 from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
@@ -12,7 +12,7 @@ from invokeai.app.invocations.fields import Input, InputField, LatentsField, Out
 from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
 from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.invocation_api import BaseInvocation, Classification, InputField, invocation, invocation_output
+from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output
 from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
 
 @invocation_output("bria_denoise_output")

View File

@@ -48,18 +48,22 @@ class BriaLatentSamplerInvocation(BaseInvocation):
         title="Transformer",
     )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
-        device = torch.device("cuda")
+        with context.models.load(self.transformer.transformer) as transformer:
+            device = transformer.device
+            dtype = transformer.dtype
+
         height, width = 1024, 1024
         generator = torch.Generator(device=device).manual_seed(self.seed)
-        num_channels_latents = 4 # due to patch=2, we devide by 4
+        num_channels_latents = 4
         latents, latent_image_ids = prepare_latents(
             batch_size=1,
             num_channels_latents=num_channels_latents,
             height=height,
             width=width,
-            dtype=torch.float32,
+            dtype=dtype,
             device=device,
             generator=generator,
         )
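
This hunk is the substantive fix: the old code pinned sampling to torch.device("cuda") and torch.float32, which breaks on CPU/MPS installs and mismatches fp16/bf16 transformers. A hedged sketch of the corrected pattern follows; the helper name is illustrative, and it assumes prepare_latents is importable from the Bria pipeline module shown in the final hunk below:

import torch

from invokeai.backend.bria.pipeline_bria_controlnet import prepare_latents  # assumed import path


def sample_bria_latents(context, transformer_field, seed: int):
    # Load the transformer through the model manager and inherit its placement,
    # instead of assuming CUDA and float32 as the removed lines did.
    with context.models.load(transformer_field.transformer) as transformer:
        device = transformer.device
        dtype = transformer.dtype

    # torch.device/torch.dtype values remain valid after the load context exits.
    generator = torch.Generator(device=device).manual_seed(seed)
    latents, latent_image_ids = prepare_latents(
        batch_size=1,
        num_channels_latents=4,  # packed channel count; see the prepare_latents hunk below
        height=1024,
        width=1024,
        dtype=dtype,
        device=device,
        generator=generator,
    )
    return latents, latent_image_ids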

View File

@@ -612,14 +612,14 @@ def encode_prompt(
 def prepare_latents(
-    batch_size,
-    num_channels_latents,
-    height,
-    width,
-    dtype,
-    device,
-    generator,
-    latents=None,
+    batch_size: int,
+    num_channels_latents: int,
+    height: int,
+    width: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    generator: torch.Generator,
+    latents: Optional[torch.FloatTensor] = None,
 ):
     # VAE applies 8x compression on images but we must also account for packing which requires
     # latent height and width to be divisible by 2.
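
The comment spells out the shape constraint; here is a worked example of the arithmetic, assuming FLUX-style 2x2 patch packing on top of the VAE's 8x spatial compression (consistent with the num_channels_latents = 4 used by the sampler hunk above):

# Worked example of the constraint described in the comment above,
# under the stated assumptions (8x VAE compression, 2x2 packing).
vae_scale_factor = 8
height, width = 1024, 1024

# Each spatial dimension shrinks 8x in the VAE; rounding through an extra
# factor of 2 and doubling back guarantees the latent grid stays even.
latent_height = 2 * (height // (vae_scale_factor * 2))  # 128
latent_width = 2 * (width // (vae_scale_factor * 2))    # 128
assert latent_height % 2 == 0 and latent_width % 2 == 0

# Packing groups 2x2 latent patches into single tokens, which is also why the
# channel count is divided by 4 (the removed "due to patch=2" comment).
packed_tokens = (latent_height // 2) * (latent_width // 2)  # 64 * 64 = 4096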