mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
fixed node issue
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
|
||||
from pydantic import BaseModel, Field
|
||||
from invokeai.invocation_api import ImageOutput
|
||||
from invokeai.invocation_api import ImageOutput, Classification
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
import numpy as np
|
||||
@@ -26,9 +26,9 @@ class BriaControlNetField(BaseModel):
|
||||
mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
|
||||
conditioning_scale: float = Field(description="The weight given to the ControlNet")
|
||||
|
||||
@invocation_output("flux_controlnet_output")
|
||||
@invocation_output("bria_controlnet_output")
|
||||
class BriaControlNetOutput(BaseInvocationOutput):
|
||||
"""FLUX ControlNet info"""
|
||||
"""Bria ControlNet info"""
|
||||
|
||||
control: BriaControlNetField = OutputField(description=FieldDescriptions.control)
|
||||
preprocessed_images: ImageField = OutputField(description="The preprocessed control image")
|
||||
@@ -40,8 +40,9 @@ class BriaControlNetOutput(BaseInvocationOutput):
|
||||
tags=["controlnet", "bria"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class BriaControlNetInvocation(BaseInvocation):
|
||||
class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Collect Bria ControlNet info to pass to denoiser node."""
|
||||
|
||||
control_image: ImageField = InputField(description="The control image")
|
||||
|
||||
@@ -3,7 +3,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
|
||||
from invokeai.backend.bria.controlnet_utils import prepare_control_images
|
||||
from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
|
||||
from invokeai.nodes.bria_nodes.bria_controlnet import BriaControlNetField
|
||||
from invokeai.app.invocations.bria_controlnet import BriaControlNetField
|
||||
|
||||
import torch
|
||||
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
||||
@@ -12,7 +12,7 @@ from invokeai.app.invocations.fields import Input, InputField, LatentsField, Out
|
||||
from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.invocation_api import BaseInvocation, Classification, InputField, invocation, invocation_output
|
||||
from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output
|
||||
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
|
||||
|
||||
@invocation_output("bria_denoise_output")
|
||||
|
||||
@@ -48,18 +48,22 @@ class BriaLatentSamplerInvocation(BaseInvocation):
|
||||
title="Transformer",
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
|
||||
device = torch.device("cuda")
|
||||
with context.models.load(self.transformer.transformer) as transformer:
|
||||
device = transformer.device
|
||||
dtype = transformer.dtype
|
||||
|
||||
height, width = 1024, 1024
|
||||
generator = torch.Generator(device=device).manual_seed(self.seed)
|
||||
|
||||
num_channels_latents = 4 # due to patch=2, we devide by 4
|
||||
num_channels_latents = 4
|
||||
latents, latent_image_ids = prepare_latents(
|
||||
batch_size=1,
|
||||
num_channels_latents=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=torch.float32,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
@@ -612,14 +612,14 @@ def encode_prompt(
|
||||
|
||||
|
||||
def prepare_latents(
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: torch.Generator,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
|
||||
Reference in New Issue
Block a user