mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Update FluxControlnetModel to work with both XLabs and InstantX.
This commit is contained in:
@@ -10,6 +10,14 @@ from safetensors.torch import load_file
|
||||
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.flux.controlnet.diffusers_controlnet_flux import DiffusersControlNetFlux
|
||||
from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
convert_diffusers_instantx_state_dict_to_bfl_format,
|
||||
infer_flux_params_from_state_dict,
|
||||
infer_instantx_num_control_modes_from_state_dict,
|
||||
is_state_dict_instantx_controlnet,
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
@@ -309,10 +317,31 @@ class FluxControlnetModel(ModelLoader):
|
||||
assert isinstance(config, ControlNetCheckpointConfig)
|
||||
model_path = Path(config.path)
|
||||
|
||||
sd = load_file(model_path)
|
||||
|
||||
# Detect the FLUX ControlNet model type from the state dict.
|
||||
if is_state_dict_xlabs_controlnet(sd):
|
||||
return self._load_xlabs_controlnet(sd)
|
||||
elif is_state_dict_instantx_controlnet(sd):
|
||||
return self._load_instantx_controlnet(sd)
|
||||
else:
|
||||
raise ValueError("Do not recognize the state dict as an XLabs or InstantX ControlNet model.")
|
||||
|
||||
def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
|
||||
with accelerate.init_empty_weights():
|
||||
# HACK(ryand): Is it safe to assume dev here?
|
||||
model = XLabsControlNetFlux(params["flux-dev"])
|
||||
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
def _load_instantx_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
|
||||
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
|
||||
flux_params = infer_flux_params_from_state_dict(sd)
|
||||
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
model = DiffusersControlNetFlux(flux_params, num_control_modes)
|
||||
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user