mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 08:38:00 -05:00
111 lines
2.3 KiB
Python
111 lines
2.3 KiB
Python
# Initially pulled from https://github.com/black-forest-labs/flux
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Dict, Literal
|
|
|
|
from invokeai.backend.flux.model import FluxParams
|
|
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
|
|
|
|
|
|
@dataclass
class ModelSpec:
    """Describe one model variant: its hyper-parameters and where its weights live.

    Every field is required (no defaults are declared). The five location
    fields are each typed ``str | None``, so any of them may be ``None``.
    """

    # Hyper-parameters for the flow model.
    params: FluxParams
    # Hyper-parameters for the autoencoder.
    ae_params: AutoEncoderParams
    # NOTE(review): presumably a local path to the flow-model checkpoint - confirm with callers.
    ckpt_path: str | None
    # NOTE(review): presumably a local path to the autoencoder weights - confirm with callers.
    ae_path: str | None
    # NOTE(review): presumably a model-hub repository id plus the file names of
    # the flow and autoencoder weights inside it - confirm with callers.
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None
|
|
|
|
|
|
# Preferred resolutions for Kontext models to avoid tiling artifacts.
# These are the specific resolutions the model was trained on.
#
# The table is mirror-symmetric around the square (1024, 1024) entry: every
# (a, b) pair before it reappears as (b, a) after it.
# NOTE(review): whether each pair is (width, height) or (height, width) is not
# visible in this file - confirm against the code that consumes the table.
# NOTE: the spelling "PREFERED" is kept on purpose; the name is public and may
# be imported elsewhere.
PREFERED_KONTEXT_RESOLUTIONS: list[tuple[int, int]] = [
    (672, 1568),
    (688, 1504),
    (720, 1456),
    (752, 1392),
    (800, 1328),
    (832, 1248),
    (880, 1184),
    (944, 1104),
    (1024, 1024),
    (1104, 944),
    (1184, 880),
    (1248, 832),
    (1328, 800),
    (1392, 752),
    (1456, 720),
    (1504, 688),
    (1568, 672),
]
|
|
|
|
|
|
# Maximum sequence length for each model variant, keyed by variant name.
# Only 256 and 512 are permitted by the declared Literal type.
# NOTE(review): presumably this is the text-encoder token limit - confirm with callers.
max_seq_lengths: Dict[str, Literal[256, 512]] = {
    "flux-dev": 512,
    "flux-dev-fill": 512,
    "flux-schnell": 256,
}
|
|
|
|
|
|
# Autoencoder hyper-parameters, keyed by model family name.
# Only the "flux" family is defined here.
ae_params = {
    "flux": AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    ),
}
|
|
|
|
|
|
# Flow-model hyper-parameters, keyed by model variant name.
# The three variants share every value except where noted on each entry.
params = {
    "flux-dev": FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
    ),
    # Same as "flux-dev" except that guidance_embed is False.
    "flux-schnell": FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=False,
    ),
    # Same as "flux-dev" except that in_channels is 384 and out_channels is
    # passed explicitly as 64 (the other variants do not pass out_channels).
    # NOTE(review): presumably the wider input carries the fill conditioning - confirm.
    "flux-dev-fill": FluxParams(
        in_channels=384,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
    ),
}
|