Fix circular imports related to XLabsControlNetFluxOutput and InstantXControlNetFluxOutput.

This commit is contained in:
Ryan Dick
2024-10-08 16:05:34 +00:00
committed by Kent Keirsey
parent 4289b5e6c3
commit 47c7df3476
9 changed files with 41 additions and 36 deletions

View File

@@ -1,11 +1,11 @@
# This file was initially copied from:
# https://github.com/huggingface/diffusers/blob/99f608218caa069a2f16dcf9efab46959b15aec0/src/diffusers/models/controlnet_flux.py
from dataclasses import dataclass
import torch
import torch.nn as nn
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import (
@@ -16,21 +16,6 @@ from invokeai.backend.flux.modules.layers import (
timestep_embedding,
)
@dataclass
class InstantXControlNetFluxOutput:
controlnet_block_samples: list[torch.Tensor] | None
controlnet_single_block_samples: list[torch.Tensor] | None
def apply_weight(self, weight: float):
if self.controlnet_block_samples is not None:
for i in range(len(self.controlnet_block_samples)):
self.controlnet_block_samples[i] = self.controlnet_block_samples[i] * weight
if self.controlnet_single_block_samples is not None:
for i in range(len(self.controlnet_single_block_samples)):
self.controlnet_single_block_samples[i] = self.controlnet_single_block_samples[i] * weight
# NOTE(ryand): Mapping between diffusers FLUX transformer params and BFL FLUX transformer params:
# - Diffusers: BFL
# - in_channels: in_channels

View File

@@ -0,0 +1,17 @@
from dataclasses import dataclass
import torch
@dataclass
class InstantXControlNetFluxOutput:
controlnet_block_samples: list[torch.Tensor] | None
controlnet_single_block_samples: list[torch.Tensor] | None
def apply_weight(self, weight: float):
if self.controlnet_block_samples is not None:
for i in range(len(self.controlnet_block_samples)):
self.controlnet_block_samples[i] = self.controlnet_block_samples[i] * weight
if self.controlnet_single_block_samples is not None:
for i in range(len(self.controlnet_single_block_samples)):
self.controlnet_single_block_samples[i] = self.controlnet_single_block_samples[i] * weight

View File

@@ -2,26 +2,15 @@
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py
from dataclasses import dataclass
import torch
from einops import rearrange
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
@dataclass
class XLabsControlNetFluxOutput:
controlnet_double_block_residuals: list[torch.Tensor] | None
def apply_weight(self, weight: float):
if self.controlnet_double_block_residuals is not None:
for i in range(len(self.controlnet_double_block_residuals)):
self.controlnet_double_block_residuals[i] = self.controlnet_double_block_residuals[i] * weight
class XLabsControlNetFlux(torch.nn.Module):
"""A ControlNet model for FLUX.

View File

@@ -0,0 +1,13 @@
from dataclasses import dataclass
import torch
@dataclass
class XLabsControlNetFluxOutput:
controlnet_double_block_residuals: list[torch.Tensor] | None
def apply_weight(self, weight: float):
if self.controlnet_double_block_residuals is not None:
for i in range(len(self.controlnet_double_block_residuals)):
self.controlnet_double_block_residuals[i] = self.controlnet_double_block_residuals[i] * weight

View File

@@ -4,8 +4,8 @@ from typing import Callable
import torch
from tqdm import tqdm
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFluxOutput
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension

View File

@@ -4,8 +4,8 @@ from typing import List, Union
import torch
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFluxOutput
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
class BaseControlNetExtension(ABC):

View File

@@ -8,8 +8,8 @@ from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
InstantXControlNetFlux,
InstantXControlNetFluxOutput,
)
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
from invokeai.backend.model_manager.load.load_base import LoadedModel

View File

@@ -5,7 +5,8 @@ from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux, XLabsControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension

View File

@@ -6,8 +6,8 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFluxOutput
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,