# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team """Class for Flux model loading in InvokeAI.""" from pathlib import Path from typing import Optional import accelerate import torch from safetensors.torch import load_file from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer from invokeai.app.services.config.config_default import get_config from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux from invokeai.backend.flux.controlnet.state_dict_utils import ( convert_diffusers_instantx_state_dict_to_bfl_format, infer_flux_params_from_state_dict, infer_instantx_num_control_modes_from_state_dict, is_state_dict_instantx_controlnet, is_state_dict_xlabs_controlnet, ) from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux from invokeai.backend.flux.ip_adapter.state_dict_utils import infer_xlabs_ip_adapter_params_from_state_dict from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import ( XlabsIpAdapterFlux, ) from invokeai.backend.flux.model import Flux from invokeai.backend.flux.modules.autoencoder import AutoEncoder from invokeai.backend.flux.util import ae_params, params from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType, ) from invokeai.backend.model_manager.config import ( CheckpointConfigBase, CLIPEmbedDiffusersConfig, ControlNetCheckpointConfig, ControlNetDiffusersConfig, IPAdapterCheckpointConfig, MainBnbQuantized4bCheckpointConfig, MainCheckpointConfig, MainGGUFCheckpointConfig, T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderConfig, VAECheckpointConfig, ) from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.util.model_util import ( convert_bundle_to_flux_transformer_checkpoint, ) from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES from invokeai.backend.util.silence_warnings import SilenceWarnings try: from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 bnb_available = True except ImportError: bnb_available = False app_config = get_config() @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint) class FluxVAELoader(ModelLoader): """Class to load VAE models.""" def _load_model( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, VAECheckpointConfig): raise ValueError("Only VAECheckpointConfig models are currently supported here.") model_path = Path(config.path) with accelerate.init_empty_weights(): model = AutoEncoder(ae_params[config.config_path]) sd = load_file(model_path) model.load_state_dict(sd, assign=True) # VAE is broken in float16, which mps defaults to if self._torch_dtype == torch.float16: try: vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype except TypeError: vae_dtype = torch.float32 else: vae_dtype = self._torch_dtype model.to(vae_dtype) return model @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers) class ClipCheckpointModel(ModelLoader): """Class to load main models.""" def _load_model( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, CLIPEmbedDiffusersConfig): raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.") match submodel_type: case SubModelType.Tokenizer: return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer") case SubModelType.TextEncoder: return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder") raise ValueError( f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" ) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b) class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader): """Class to load main models.""" def _load_model( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig): raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.") if not bnb_available: raise ImportError( "The bnb modules are not available. Please install bitsandbytes if available on your platform." ) match submodel_type: case SubModelType.Tokenizer2 | SubModelType.Tokenizer3: return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512) case SubModelType.TextEncoder2 | SubModelType.TextEncoder3: te2_model_path = Path(config.path) / "text_encoder_2" model_config = AutoConfig.from_pretrained(te2_model_path) with accelerate.init_empty_weights(): model = AutoModelForTextEncoding.from_config(model_config) model = quantize_model_llm_int8(model, modules_to_not_convert=set()) state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors" state_dict = load_file(state_dict_path) self._load_state_dict_into_t5(model, state_dict) return model raise ValueError( f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" ) @classmethod def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]): # There is a shared reference to a single weight tensor in the model. # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should # be present in the state_dict. missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True) assert len(unexpected_keys) == 0 assert set(missing_keys) == {"encoder.embed_tokens.weight"} # Assert that the layers we expect to be shared are actually shared. assert model.encoder.embed_tokens.weight is model.shared.weight @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) class T5EncoderCheckpointModel(ModelLoader): """Class to load main models.""" def _load_model( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, T5EncoderConfig): raise ValueError("Only T5EncoderConfig models are currently supported here.") match submodel_type: case SubModelType.Tokenizer2 | SubModelType.Tokenizer3: return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512) case SubModelType.TextEncoder2 | SubModelType.TextEncoder3: return T5EncoderModel.from_pretrained( Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True ) raise ValueError( f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" ) @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint) class FluxCheckpointModel(ModelLoader): """Class to load main models.""" def _load_model( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, CheckpointConfigBase): raise ValueError("Only CheckpointConfigBase models are currently supported here.") match submodel_type: case SubModelType.Transformer: return self._load_from_singlefile(config) raise ValueError( f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" ) def _load_from_singlefile( self, config: AnyModelConfig, ) -> AnyModel: assert isinstance(config, MainCheckpointConfig) model_path = Path(config.path) with accelerate.init_empty_weights(): model = Flux(params[config.config_path]) sd = load_file(model_path) if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: sd = convert_bundle_to_flux_transformer_checkpoint(sd) new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()]) self._ram_cache.make_room(new_sd_size) for k in sd.keys(): # We need to cast to bfloat16 due to it being the only currently supported dtype for inference sd[k] = sd[k].to(torch.bfloat16) model.load_state_dict(sd, assign=True) return model @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.GGUFQuantized) class FluxGGUFCheckpointModel(ModelLoader): """Class to load GGUF main models.""" def _load_model( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, CheckpointConfigBase): raise ValueError("Only CheckpointConfigBase models are currently supported here.") match submodel_type: case SubModelType.Transformer: return self._load_from_singlefile(config) raise ValueError( f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" ) def _load_from_singlefile( self, config: AnyModelConfig, ) -> AnyModel: assert isinstance(config, MainGGUFCheckpointConfig) model_path = Path(config.path) with accelerate.init_empty_weights(): model = Flux(params[config.config_path]) # HACK(ryand): We shouldn't be hard-coding the compute_dtype here. sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16) # HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight. # We override the shape here to fix the issue. # Example model with this issue (Q4_K_M): https://civitai.com/models/705823/ggufk-flux-unchained-km-quants img_in_weight = sd.get("img_in.weight", None) if img_in_weight is not None and img_in_weight._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES: expected_img_in_weight_shape = model.img_in.weight.shape img_in_weight.quantized_data = img_in_weight.quantized_data.view(expected_img_in_weight_shape) img_in_weight.tensor_shape = expected_img_in_weight_shape model.load_state_dict(sd, assign=True) return model @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b) class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): """Class to load main models.""" def _load_model( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, CheckpointConfigBase): raise ValueError("Only CheckpointConfigBase models are currently supported here.") match submodel_type: case SubModelType.Transformer: return self._load_from_singlefile(config) raise ValueError( f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" ) def _load_from_singlefile( self, config: AnyModelConfig, ) -> AnyModel: assert isinstance(config, MainBnbQuantized4bCheckpointConfig) if not bnb_available: raise ImportError( "The bnb modules are not available. Please install bitsandbytes if available on your platform." ) model_path = Path(config.path) with SilenceWarnings(): with accelerate.init_empty_weights(): model = Flux(params[config.config_path]) model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) sd = load_file(model_path) if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: sd = convert_bundle_to_flux_transformer_checkpoint(sd) model.load_state_dict(sd, assign=True) return model @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Diffusers) class FluxControlnetModel(ModelLoader): """Class to load FLUX ControlNet models.""" def _load_model( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if isinstance(config, ControlNetCheckpointConfig): model_path = Path(config.path) elif isinstance(config, ControlNetDiffusersConfig): # If this is a diffusers directory, we simply ignore the config file and load from the weight file. model_path = Path(config.path) / "diffusion_pytorch_model.safetensors" else: raise ValueError(f"Unexpected ControlNet model config type: {type(config)}") sd = load_file(model_path) # Detect the FLUX ControlNet model type from the state dict. if is_state_dict_xlabs_controlnet(sd): return self._load_xlabs_controlnet(sd) elif is_state_dict_instantx_controlnet(sd): return self._load_instantx_controlnet(sd) else: raise ValueError("Do not recognize the state dict as an XLabs or InstantX ControlNet model.") def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel: with accelerate.init_empty_weights(): # HACK(ryand): Is it safe to assume dev here? model = XLabsControlNetFlux(params["flux-dev"]) model.load_state_dict(sd, assign=True) return model def _load_instantx_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel: sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd) flux_params = infer_flux_params_from_state_dict(sd) num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd) with accelerate.init_empty_weights(): model = InstantXControlNetFlux(flux_params, num_control_modes) model.load_state_dict(sd, assign=True) return model @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint) class FluxIpAdapterModel(ModelLoader): """Class to load FLUX IP-Adapter models.""" def _load_model( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, IPAdapterCheckpointConfig): raise ValueError(f"Unexpected model config type: {type(config)}.") sd = load_file(Path(config.path)) params = infer_xlabs_ip_adapter_params_from_state_dict(sd) with accelerate.init_empty_weights(): model = XlabsIpAdapterFlux(params=params) model.load_xlabs_state_dict(sd, assign=True) return model