Compare commits

...

10 Commits

10 changed files with 482 additions and 8 deletions

View File

@@ -6,6 +6,7 @@ import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
@@ -19,6 +20,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.controlnet_flux import ControlNetFlux
from invokeai.backend.flux.controlnet_extension import ControlNetExtension
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
@@ -44,7 +47,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.0.0",
version="3.1.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -87,6 +90,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
controlnet: ControlField | list[ControlField] | None = InputField(
default=None, input=Input.Connection, description="ControlNet models."
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -167,8 +173,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inpaint_mask = self._prep_inpaint_mask(context, x)
b, _c, h, w = x.shape
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
b, _c, latent_h, latent_w = x.shape
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
@@ -231,6 +237,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
else:
raise ValueError(f"Unsupported model format: {config.format}")
# Prepare ControlNet extensions.
controlnet_extensions = self._prep_controlnet_extensions(
context=context,
exit_stack=exit_stack,
latent_height=latent_h,
latent_width=latent_w,
dtype=inference_dtype,
device=x.device,
)
x = denoise(
model=transformer,
img=x,
@@ -242,6 +258,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
step_callback=self._build_step_callback(context),
guidance=self.guidance,
inpaint_extension=inpaint_extension,
controlnet_extensions=controlnet_extensions,
)
x = unpack(x.float(), self.height, self.width)
@@ -288,6 +305,50 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# `latents`.
return mask.expand_as(latents)
def _prep_controlnet_extensions(
self,
context: InvocationContext,
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[ControlNetExtension] | None:
# Normalize the controlnet input to list[ControlField].
controlnets: list[ControlField]
if self.controlnet is None:
return None
elif isinstance(self.controlnet, ControlField):
controlnets = [self.controlnet]
elif isinstance(self.controlnet, list):
controlnets = self.controlnet
else:
raise ValueError(f"Unsupported controlnet type: {type(self.controlnet)}")
controlnet_extensions: list[ControlNetExtension] = []
for controlnet in controlnets:
model = exit_stack.enter_context(context.models.load(controlnet.control_model))
assert isinstance(model, ControlNetFlux)
image = context.images.get_pil(controlnet.image.image_name)
controlnet_extensions.append(
ControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
control_mode=controlnet.control_mode,
resize_mode=controlnet.resize_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
return controlnet_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)

View File

@@ -0,0 +1,130 @@
# This file was initially based on:
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py
import torch
from einops import rearrange
from torch import Tensor, nn
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
def _zero_module(module: torch.nn.Module) -> torch.nn.Module:
"""Initialize the parameters of a module to zero."""
for p in module.parameters():
nn.init.zeros_(p)
return module
class ControlNetFlux(nn.Module):
"""A ControlNet model for FLUX.
The architecture is very similar to the base FLUX model, with the following differences:
- A `controlnet_depth` parameter is passed to control the number of double_blocks that the ControlNet is applied to.
In order to keep the ControlNet small, this is typically much less than the depth of the base FLUX model.
- There is a set of `controlnet_blocks` that are applied to the output of each double_block.
"""
def __init__(self, params: FluxParams, controlnet_depth: int = 2):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(controlnet_depth)
]
)
# Add ControlNet blocks.
self.controlnet_blocks = nn.ModuleList([])
for _ in range(controlnet_depth):
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = _zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.input_hint_block = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
_zero_module(nn.Conv2d(16, 16, 3, padding=1)),
)
def forward(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
) -> list[Tensor]:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples: list[torch.Tensor] = []
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
block_res_samples.append(img)
controlnet_block_res_samples: list[torch.Tensor] = []
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks, strict=True):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples.append(block_res_sample)
return controlnet_block_res_samples

View File

@@ -0,0 +1,103 @@
import math
from typing import List, Union
import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux import ControlNetFlux
class ControlNetExtension:
def __init__(
self,
model: ControlNetFlux,
controlnet_cond: torch.Tensor,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
self._model = model
# _controlnet_cond is the control image passed to the ControlNet model.
# Pixel values are in the range [-1, 1]. Shape: (batch_size, 3, height, width).
self._controlnet_cond = controlnet_cond
self._weight = weight
self._begin_step_percent = begin_step_percent
self._end_step_percent = end_step_percent
@classmethod
def from_controlnet_image(
cls,
model: ControlNetFlux,
controlnet_image: Image,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
control_mode: CONTROLNET_MODE_VALUES,
resize_mode: CONTROLNET_RESIZE_VALUES,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
controlnet_cond = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode=control_mode,
resize_mode=resize_mode,
)
# Map pixel values from [0, 1] to [-1, 1].
controlnet_cond = controlnet_cond * 2 - 1
return cls(
model=model,
controlnet_cond=controlnet_cond,
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
def run_controlnet(
self,
timestep_index: int,
total_num_timesteps: int,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> list[torch.Tensor] | None:
first_step = math.floor(self._begin_step_percent * total_num_timesteps)
last_step = math.ceil(self._end_step_percent * total_num_timesteps)
if timestep_index < first_step or timestep_index > last_step:
return
weight = self._weight
controlnet_block_res_samples = self._model(
img=img,
img_ids=img_ids,
controlnet_cond=self._controlnet_cond,
txt=txt,
txt_ids=txt_ids,
timesteps=timesteps,
y=y,
guidance=guidance,
)
# Apply weight to the residuals.
for block_res_sample in controlnet_block_res_samples:
block_res_sample *= weight
return controlnet_block_res_samples

View File

@@ -3,6 +3,7 @@ from typing import Callable
import torch
from tqdm import tqdm
from invokeai.backend.flux.controlnet_extension import ControlNetExtension
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -21,6 +22,7 @@ def denoise(
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
inpaint_extension: InpaintExtension | None,
controlnet_extensions: list[ControlNetExtension] | None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -38,6 +40,25 @@ def denoise(
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# Run ControlNet models.
# controlnet_block_residuals[i][j] is the residual of the j-th block of the i-th ControlNet model.
controlnet_block_residuals: list[list[torch.Tensor] | None] = []
for controlnet_extension in controlnet_extensions or []:
controlnet_block_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step - 1,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
)
pred = model(
img=img,
img_ids=img_ids,
@@ -46,6 +67,7 @@ def denoise(
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
controlnet_block_residuals=controlnet_block_residuals,
)
preview_img = img - t_curr * pred

View File

@@ -88,6 +88,7 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
controlnet_block_residuals: list[list[Tensor] | None] | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -105,9 +106,15 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
for block_index, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
# Apply ControlNet residuals.
if controlnet_block_residuals is not None:
for single_controlnet_block_residuals in controlnet_block_residuals:
if single_controlnet_block_residuals:
img += single_controlnet_block_residuals[block_index % len(single_controlnet_block_residuals)]
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)

View File

@@ -8,17 +8,36 @@ from diffusers import ControlNetModel
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
)
from invokeai.backend.model_manager.config import (
BaseModelType,
ControlNetCheckpointConfig,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""

View File

@@ -10,6 +10,7 @@ from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.controlnet.controlnet_flux import ControlNetFlux
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params
@@ -24,6 +25,7 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
ControlNetCheckpointConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
MainGGUFCheckpointConfig,
@@ -293,3 +295,24 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
model.load_state_dict(sd, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class FluxControlnetModel(ModelLoader):
"""Class to load FLUX ControlNet models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
assert isinstance(config, ControlNetCheckpointConfig)
model_path = Path(config.path)
with accelerate.init_empty_weights():
# HACK(ryand): Is it safe to assume dev here?
model = ControlNetFlux(params["flux-dev"])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
return model

View File

@@ -255,7 +255,19 @@ class ModelProbe(object):
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return ModelType.LoRA
elif key.startswith(("controlnet", "control_model", "input_blocks")):
elif key.startswith(
(
"controlnet",
"control_model",
"input_blocks",
# XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
# For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
# TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
# delicate.
"controlnet_blocks",
)
):
return ModelType.ControlNet
elif key.startswith(("image_proj.", "ip_adapter.")):
return ModelType.IPAdapter
@@ -623,6 +635,12 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if "double_blocks.0.img_attn.qkv.weight" in checkpoint:
# TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
# get_format()?
return BaseModelType.Flux
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"controlnet_mid_block.bias",

View File

@@ -0,0 +1,91 @@
# State dict keys for an XLabs FLUX ControlNet model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
state_dict_keys = [
"controlnet_blocks.0.bias",
"controlnet_blocks.0.weight",
"controlnet_blocks.1.bias",
"controlnet_blocks.1.weight",
"double_blocks.0.img_attn.norm.key_norm.scale",
"double_blocks.0.img_attn.norm.query_norm.scale",
"double_blocks.0.img_attn.proj.bias",
"double_blocks.0.img_attn.proj.weight",
"double_blocks.0.img_attn.qkv.bias",
"double_blocks.0.img_attn.qkv.weight",
"double_blocks.0.img_mlp.0.bias",
"double_blocks.0.img_mlp.0.weight",
"double_blocks.0.img_mlp.2.bias",
"double_blocks.0.img_mlp.2.weight",
"double_blocks.0.img_mod.lin.bias",
"double_blocks.0.img_mod.lin.weight",
"double_blocks.0.txt_attn.norm.key_norm.scale",
"double_blocks.0.txt_attn.norm.query_norm.scale",
"double_blocks.0.txt_attn.proj.bias",
"double_blocks.0.txt_attn.proj.weight",
"double_blocks.0.txt_attn.qkv.bias",
"double_blocks.0.txt_attn.qkv.weight",
"double_blocks.0.txt_mlp.0.bias",
"double_blocks.0.txt_mlp.0.weight",
"double_blocks.0.txt_mlp.2.bias",
"double_blocks.0.txt_mlp.2.weight",
"double_blocks.0.txt_mod.lin.bias",
"double_blocks.0.txt_mod.lin.weight",
"double_blocks.1.img_attn.norm.key_norm.scale",
"double_blocks.1.img_attn.norm.query_norm.scale",
"double_blocks.1.img_attn.proj.bias",
"double_blocks.1.img_attn.proj.weight",
"double_blocks.1.img_attn.qkv.bias",
"double_blocks.1.img_attn.qkv.weight",
"double_blocks.1.img_mlp.0.bias",
"double_blocks.1.img_mlp.0.weight",
"double_blocks.1.img_mlp.2.bias",
"double_blocks.1.img_mlp.2.weight",
"double_blocks.1.img_mod.lin.bias",
"double_blocks.1.img_mod.lin.weight",
"double_blocks.1.txt_attn.norm.key_norm.scale",
"double_blocks.1.txt_attn.norm.query_norm.scale",
"double_blocks.1.txt_attn.proj.bias",
"double_blocks.1.txt_attn.proj.weight",
"double_blocks.1.txt_attn.qkv.bias",
"double_blocks.1.txt_attn.qkv.weight",
"double_blocks.1.txt_mlp.0.bias",
"double_blocks.1.txt_mlp.0.weight",
"double_blocks.1.txt_mlp.2.bias",
"double_blocks.1.txt_mlp.2.weight",
"double_blocks.1.txt_mod.lin.bias",
"double_blocks.1.txt_mod.lin.weight",
"guidance_in.in_layer.bias",
"guidance_in.in_layer.weight",
"guidance_in.out_layer.bias",
"guidance_in.out_layer.weight",
"img_in.bias",
"img_in.weight",
"input_hint_block.0.bias",
"input_hint_block.0.weight",
"input_hint_block.10.bias",
"input_hint_block.10.weight",
"input_hint_block.12.bias",
"input_hint_block.12.weight",
"input_hint_block.14.bias",
"input_hint_block.14.weight",
"input_hint_block.2.bias",
"input_hint_block.2.weight",
"input_hint_block.4.bias",
"input_hint_block.4.weight",
"input_hint_block.6.bias",
"input_hint_block.6.weight",
"input_hint_block.8.bias",
"input_hint_block.8.weight",
"pos_embed_input.bias",
"pos_embed_input.weight",
"time_in.in_layer.bias",
"time_in.in_layer.weight",
"time_in.out_layer.bias",
"time_in.out_layer.weight",
"txt_in.bias",
"txt_in.weight",
"vector_in.in_layer.bias",
"vector_in.in_layer.weight",
"vector_in.out_layer.bias",
"vector_in.out_layer.weight",
]