added support for loading bria transformer

Ubuntu
2025-07-09 10:32:58 +00:00
committed by Kent Keirsey
parent dfc7835359
commit 7f3e8087ba
5 changed files with 1096 additions and 0 deletions

View File

View File

@@ -0,0 +1,314 @@
import math
import os
from typing import List, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from diffusers.utils import logging
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_t5_prompt_embeds(
tokenizer: T5TokenizerFast,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str], None] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
):
device = device or text_encoder.device
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
# padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
# Concat zeros to max_sequence
b, seq_len, dim = prompt_embeds.shape
if seq_len < max_sequence_length:
padding = torch.zeros(
(b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
prompt_embeds = prompt_embeds.to(device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
# in order to get the same sigmas as in training and sample from them
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
sigmas = timesteps / num_train_timesteps
inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
new_sigmas = sigmas[inds]
return new_sigmas
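# Illustrative sketch (values follow from the construction above; not used by the code):
#   sigmas = get_original_sigmas(num_train_timesteps=1000, num_inference_steps=30)
#   sigmas[0]  -> 1.000  (largest training sigma, index 0)
#   sigmas[-1] -> 0.001  (smallest training sigma, index 999)
# i.e. 30 sigmas are picked at evenly spaced indices over the 1000 training sigmas.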
def is_ng_none(negative_prompt):
return (
negative_prompt is None
or negative_prompt == ""
or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
)
class CudaTimerContext:
def __init__(self, times_arr):
self.times_arr = times_arr
def __enter__(self):
self.before_event = torch.cuda.Event(enable_timing=True)
self.after_event = torch.cuda.Event(enable_timing=True)
self.before_event.record()
def __exit__(self, type, value, traceback):
self.after_event.record()
torch.cuda.synchronize()
elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000
self.times_arr.append(elapsed_time)
def get_env_prefix():
env = os.environ.get("CLOUD_PROVIDER", "AWS").upper()
if env == "AWS":
return "SM_CHANNEL"
elif env == "AZURE":
return "AZUREML_DATAREFERENCE"
raise Exception(f"Env {env} not supported")
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: Optional[float] = None, logit_std: Optional[float] = None, mode_scale: Optional[float] = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
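# Illustrative sketch (example values only, not used anywhere in the pipeline):
#   compute_loss_weighting_for_sd3("sigma_sqrt", sigmas=torch.tensor([0.5]))  -> tensor([4.0])     # 0.5 ** -2
#   compute_loss_weighting_for_sd3("cosmap",     sigmas=torch.tensor([0.5]))  -> tensor([1.2732])  # 2 / (pi * 0.5)
# Any other scheme falls back to torch.ones_like(sigmas).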
def initialize_distributed():
# Initialize the process group for distributed training
dist.init_process_group("nccl")
# Get the current process's rank (ID) and the total number of processes (world size)
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"Initialized distributed training: Rank {rank}/{world_size}")
def get_clip_prompt_embeds(
text_encoder: CLIPTextModel,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 77,
device: Optional[torch.device] = None,
):
device = device or text_encoder.device
assert max_sequence_length == tokenizer.model_max_length
prompt = [prompt] if isinstance(prompt, str) else prompt
# Define tokenizers and text encoders
tokenizers = [tokenizer, tokenizer_2]
text_encoders = [text_encoder, text_encoder_2]
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)
# We are always interested only in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, pooled_prompt_embeds
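# Illustrative shapes (assuming SDXL-style CLIP encoders with hidden sizes 768 and 1280;
# the actual Bria checkpoints may differ): with a single prompt and num_images_per_prompt=1,
#   prompt_embeds        -> [1, 77, 2048]  (hidden_states[-2] of both encoders, concatenated)
#   pooled_prompt_embeds -> [1, 1280]      (pooled output of the second encoder)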
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in
complex64 data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
linear_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concatenated with themselves.
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
the dtype of the frequency tensor.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio, allegro
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
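# Illustrative usage (shapes only, not part of the pipeline): for one RoPE axis of dim=56
# over 64 positions,
#   cos, sin = get_1d_rotary_pos_embed(56, 64, use_real=True)    # cos.shape == sin.shape == [64, 56]
#   freqs_cis = get_1d_rotary_pos_embed(56, 64, use_real=False)  # complex tensor of shape [64, 28]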
class FluxPosEmbed(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin

View File

@@ -0,0 +1,459 @@
#!/usr/bin/env python
"""
Bria Text-to-Image Pipeline (GPU-ready)
Uses local Bria checkpoints.
"""
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
# Local bria_utils helpers
from .bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from PIL import Image
from tqdm import tqdm
# Custom Bria transformer
from .transformer_bria import BriaTransformer2DModel
from transformers import T5EncoderModel, T5TokenizerFast
# -----------------------------------------------------------------------------
# 1. Model Loader
# -----------------------------------------------------------------------------
class BriaModelLoader:
def __init__(
self,
transformer_ckpt: str,
vae_ckpt: str,
text_encoder_ckpt: str,
tokenizer_ckpt: str,
device: torch.device,
):
self.device = device
# print("Loading Bria Transformer from", transformer_ckpt)
# self.transformer = BriaTransformer2DModel.from_pretrained(transformer_ckpt, torch_dtype=torch.bfloat16).to(device)
# print("Loading VAE from", vae_ckpt)
# self.vae = AutoencoderKL.from_pretrained(vae_ckpt, torch_dtype=torch.float32).to(device)
# print("Loading T5 Encoder from", text_encoder_ckpt)
# self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_ckpt, torch_dtype=torch.float16).to(device)
# print("Loading Tokenizer from", tokenizer_ckpt)
# self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_ckpt, legacy=False)
self.transformer = BriaTransformer2DModel.from_pretrained(transformer_ckpt, torch_dtype=torch.float16).to(
device
)
self.vae = AutoencoderKL.from_pretrained(vae_ckpt, torch_dtype=torch.float16).to(device)
self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_ckpt, torch_dtype=torch.float16).to(device)
self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_ckpt)
def get(self):
return {
"transformer": self.transformer,
"vae": self.vae,
"text_encoder": self.text_encoder,
"tokenizer": self.tokenizer,
}
# -----------------------------------------------------------------------------
# 2. Text Encoder (uses bria_utils)
# -----------------------------------------------------------------------------
class BriaTextEncoder:
def __init__(
self,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
device: torch.device,
max_length: int = 128,
):
self.model = text_encoder.to(device)
self.tokenizer = tokenizer
self.device = device
self.max_length = max_length
def encode(
self,
prompt: str,
negative_prompt: Optional[str] = None,
num_images_per_prompt: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# 1) get positive embeddings
pos = get_t5_prompt_embeds(
tokenizer=self.tokenizer,
text_encoder=self.model,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=self.max_length,
device=self.device,
)
# 2) get negative or zeros
if negative_prompt is None or is_ng_none(negative_prompt):
neg = torch.zeros_like(pos)
else:
neg = get_t5_prompt_embeds(
tokenizer=self.tokenizer,
text_encoder=self.model,
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=self.max_length,
device=self.device,
)
# 3) build text_ids: shape [S_text, 3]
# S_text = number of tokens = pos.shape[1]
S_text = pos.shape[1]
text_ids = torch.zeros((S_text, 3), device=self.device, dtype=torch.long)
print(f"Text embeds shapes → pos: {pos.shape}, neg: {neg.shape}, text_ids: {text_ids.shape}")
return pos, neg, text_ids
# -----------------------------------------------------------------------------
# 3. Latent Sampler
# -----------------------------------------------------------------------------
class BriaLatentSampler:
def __init__(self, transformer: BriaTransformer2DModel, vae: AutoencoderKL, device: torch.device):
self.device = device
self.latent_channels = transformer.config.in_channels
# self.latent_height = vae.config.sample_size
# self.latent_width = vae.config.sample_size
self.latent_height = 128
self.latent_width = 128
@staticmethod
def _prepare_latent_image_ids(batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype):
# Build the same img_ids FluxPipeline.prepare_latents would use
latent_image_ids = torch.zeros((height, width, 3), device=device, dtype=dtype)
latent_image_ids[..., 1] = torch.arange(height, device=device)[:, None]
latent_image_ids[..., 2] = torch.arange(width, device=device)[None, :]
# reshape to [1, height*width, 3] then repeat for batch
latent_image_ids = latent_image_ids.view(1, height * width, 3)
return latent_image_ids.repeat(batch_size, 1, 1)
def sample(self, batch_size: int = 1, seed: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
gen = torch.Generator(device=self.device).manual_seed(seed)
# 1) sample & pack the noise exactly as before
shrunk = self.latent_channels // 4
noise4d = torch.randn(
(batch_size, shrunk, self.latent_height, self.latent_width),
device=self.device,
generator=gen,
)
latents = (
noise4d.view(batch_size, shrunk, self.latent_height // 2, 2, self.latent_width // 2, 2)
.permute(0, 2, 4, 1, 3, 5)
.reshape(batch_size, (self.latent_height // 2) * (self.latent_width // 2), shrunk * 4)
)
# 2) build the matching latent_image_ids
latent_image_ids = self._prepare_latent_image_ids(
batch_size,
self.latent_height // 2,
self.latent_width // 2,
device=self.device,
dtype=torch.long,
)
if latent_image_ids.ndim == 3 and latent_image_ids.shape[0] == 1:
latent_image_ids = latent_image_ids[0]  # [S_img, 3]
print(f"Sampled & packed latents: {latents.shape}")
return latents, latent_image_ids
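# For orientation (derived from the defaults above, not executed): with in_channels=64 the
# sampler draws a [B, 16, 128, 128] noise tensor, groups it into 2x2 spatial patches and
# flattens it to [B, 4096, 64] (seq_len = 64 * 64, channels = 16 * 4), matching the
# [4096, 3] latent_image_ids returned alongside it.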
# -----------------------------------------------------------------------------
# 4. Denoising Loop (uses bria_utils for σ schedule)
# -----------------------------------------------------------------------------
class BriaDenoise:
def __init__(
self,
transformer: nn.Module,
scheduler_name: str,
device: torch.device,
num_train_timesteps: int,
num_inference_steps: int,
**sched_kwargs,
):
self.transformer = transformer.to(device)
self.device = device
# Build scheduler
if scheduler_name == "flow_match":
from diffusers import FlowMatchEulerDiscreteScheduler
self.scheduler = FlowMatchEulerDiscreteScheduler.from_config(transformer.config, **sched_kwargs)
else:
from diffusers import DDIMScheduler
self.scheduler = DDIMScheduler(**sched_kwargs)
# Use the exact σ schedule from bria_utils (get_original_sigmas is imported at the top of this module)
sigmas = get_original_sigmas(
num_train_timesteps=num_train_timesteps,
num_inference_steps=num_inference_steps,
)
self.scheduler.set_timesteps(
num_inference_steps=None,
timesteps=None,
sigmas=sigmas,
device=device,
)
# allow early exit
self.interrupt = False
# will be set in denoise()
self._guidance_scale = 1.0
self._joint_attention_kwargs = {}
@property
def guidance_scale(self) -> float:
return self._guidance_scale
@property
def do_classifier_free_guidance(self) -> bool:
return self.guidance_scale > 1.0
@property
def joint_attention_kwargs(self) -> dict:
return self._joint_attention_kwargs
@torch.no_grad()
def denoise(
self,
latents: torch.Tensor, # [B, seq_len, C_hidden]
latent_image_ids: torch.Tensor,  # [S_img, 3]
prompt_embeds: torch.Tensor,  # [B, S_text, D]
negative_prompt_embeds: torch.Tensor,  # [B, S_text, D]
text_ids: torch.Tensor,  # [S_text, 3]
num_inference_steps: int = 30,
guidance_scale: float = 5.0,
normalize: bool = False,
clip_value: float | None = None,
seed: int = 0,
) -> torch.Tensor:
# 0) Quick cast & setup
device = self.device
# ensure dtype matches transformer
target_dtype = next(self.transformer.parameters()).dtype
latents = latents.to(device, dtype=target_dtype)
prompt_embeds = prompt_embeds.to(device, dtype=target_dtype)
# replicate reference encode_prompt behaviour (guard before the dtype cast)
if negative_prompt_embeds is None:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_prompt_embeds = negative_prompt_embeds.to(device, dtype=target_dtype)
if guidance_scale > 1.0:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
self._guidance_scale = guidance_scale
# 1) Prepare flow-match timesteps identical to the reference pipeline
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler) and getattr(
self.scheduler.config, "use_dynamic_shifting", False
):
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(image_seq_len, 256, 16_384, 0.25, 0.75)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, None, sigmas, mu=mu
)
else:
sigmas = get_original_sigmas(
num_train_timesteps=self.scheduler.config.num_train_timesteps,
num_inference_steps=num_inference_steps,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, None, sigmas
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 2) Loop with progress bar
with tqdm(total=num_inference_steps, desc="Denoising", unit="step") as progress_bar:
for i, t in enumerate(timesteps):
# a) expand for CFG?
latent_model_input = torch.cat([latents] * 2, dim=0) if self.do_classifier_free_guidance else latents
# b) scale model input if needed
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# c) broadcast timestep
timestep = t.expand(latent_model_input.shape[0])
# d) predict noise
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
)[0]
# e) classifier-free guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
cfg_noise_pred_text = noise_pred_text.std()
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# f) optional normalize/clip (normalization relies on the CFG split above)
if normalize and self.do_classifier_free_guidance:
noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred
if clip_value:
noise_pred = noise_pred.clamp(-clip_value, clip_value)
# g) scheduler step, in-place
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
progress_bar.update(1)
# # j) XLA sync
# if XLA_AVAILABLE:
# xm.mark_step()
# 3) Return the final packed latents (still [B, seq_len, C_hidden])
return latents
# -----------------------------------------------------------------------------
# 5. Latents → Image
# -----------------------------------------------------------------------------
class BriaLatentsToImage:
def __init__(self, vae: AutoencoderKL, device: torch.device):
self.vae = vae.to(device)
self.device = device
@torch.no_grad()
def decode(self, latents: torch.Tensor) -> list[Image.Image]:
"""
Accepts either of the two packed shapes that come out of the denoiser
• [B, S, 16]    - 3-D, where S = H² (e.g. 16 384 for 1024×1024)
• [B, 1, S, 16] - 4-D mis-ordered variant
Converts them to the VAE's expected shape [B, 4, H, W] before decoding.
"""
# ---- 1. Un-pack to (B, 4, H, W) ----------------------------------
if latents.ndim == 3: # (B,S,16)
B, S, C = latents.shape
H2 = int(S**0.5) # 128 for 1024×1024
latents = (
latents.view(B, H2, H2, 4, 2, 2) # split channels into 4×(2×2)
.permute(0, 3, 1, 4, 2, 5) # (B,4,H2,2,W2,2)
.reshape(B, 4, H2 * 2, H2 * 2) # (B,4,H,W)
)
elif latents.ndim == 4 and latents.shape[1] == 1: # (B,1,S,16)
B, _, S, C = latents.shape
H2 = int(S**0.5)
latents = (
latents.squeeze(1) # -> (B,S,16)
.view(B, H2, H2, 4, 2, 2)
.permute(0, 3, 1, 4, 2, 5)
.reshape(B, 4, H2 * 2, H2 * 2)
)
# else: already (B,4,H,W)
# ---- 2. Standard VAE decode -----------------------------------------
shift = 0 if self.vae.config.shift_factor is None else self.vae.config.shift_factor
latents = (latents / self.vae.config.scaling_factor) + shift
# temporarily move the VAE to fp32 for the forward pass
self.vae.to(dtype=torch.float32)
images = self.vae.decode(latents.to(torch.float32)).sample  # full-precision decode
self.vae.to(dtype=torch.bfloat16)  # restore reduced precision after decode
images = (images.clamp(-1, 1) + 1) / 2 # [0,1] fp32
images = (images.cpu().permute(0, 2, 3, 1).numpy() * 255).astype("uint8")
return [Image.fromarray(img) for img in images]
# -----------------------------------------------------------------------------
# Main: Assemble & Run
# -----------------------------------------------------------------------------
def main():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# --- Use your actual checkpoint locations ---
transformer_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/transformer"
vae_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/vae"
text_encoder_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/text_encoder"
tokenizer_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/tokenizer"
# 1. Load models
loader = BriaModelLoader(
transformer_ckpt,
vae_ckpt,
text_encoder_ckpt,
tokenizer_ckpt,
device,
)
mdl = loader.get()
# if diffusers.__version__ >= "0.27.0":
# mdl["transformer"].enable_xformers_memory_efficient_attention() # now safe
# else:
# mdl["transformer"].disable_xformers_memory_efficient_attention() # keep quality
# 2. Encode prompt, now capturing text_ids as well
text_enc = BriaTextEncoder(mdl["text_encoder"], mdl["tokenizer"], device)
pos_embeds, neg_embeds, text_ids = text_enc.encode(
prompt="3d rendered image, landscape made out of ice cream, rich ice cream textures, ice cream-valley , with a milky ice cream river, the ice cream has rich texture with visible chocolate chunks and intricate details, in the background an air balloon floats over the vally, in the sky visible dramatic like clouds, brown-chocolate color white and pink pallet, drama, beautiful surreal landscape, polarizing lens, very high contrast, 3d rendered realistic",
negative_prompt=None,
num_images_per_prompt=1,
)
# 3. Sample initial noise → get both latents & latent_image_ids
sampler = BriaLatentSampler(mdl["transformer"], mdl["vae"], device)
init_latents, latent_image_ids = sampler.sample(batch_size=1, seed=1249141701)
# 4. Denoise, now passing latent_image_ids and text_ids
denoiser = BriaDenoise(
transformer=mdl["transformer"],
scheduler_name="flow_match",
device=device,
num_train_timesteps=1000,
num_inference_steps=30,
base_shift=0.5,
max_shift=1.15,
)
final_latents = denoiser.denoise(
init_latents,
latent_image_ids,
pos_embeds,
neg_embeds,
text_ids,
num_inference_steps=30,
guidance_scale=5.0,
seed=1249141701,
)
# 5. Decode
decoder = BriaLatentsToImage(mdl["vae"], device)
images = decoder.decode(final_latents)
for i, img in enumerate(images):
img.save(f"bria_output_{i}.png")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,320 @@
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from .bria_utils import FluxPosEmbed as EmbedND
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Timesteps(nn.Module):
def __init__(
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
self.time_theta = time_theta
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
max_period=self.time_theta,
)
return t_emb
class TimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, time_theta):
super().__init__()
self.time_proj = Timesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
return timesteps_emb
"""
Based on FluxPipeline with several changes:
- no pooled embeddings
- We use zero padding for prompts
- No guidance embedding since this is not a distilled version
"""
class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
The Transformer model introduced in Flux.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: Optional[int] = None,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
rope_theta=10000,
time_theta=10000,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
# if pooled_projection_dim:
# self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu")
if guidance_embeds:
self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`BriaTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype)
# temb = (
# self.time_text_embed(timestep, pooled_projections)
# if guidance is None
# else self.time_text_embed(timestep, guidance, pooled_projections)
# )
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
# if pooled_projections:
# temb+=self.pooled_text_embed(pooled_projections)
if guidance is not None:
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if len(txt_ids.shape) == 2:
ids = torch.cat((txt_ids, img_ids), dim=0)
else:
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
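# Illustrative call shapes (assumptions based on the defaults above; batch=1, a 1024x1024
# image packed to 4096 latent tokens, 128 text tokens):
#   hidden_states:         [1, 4096, 64]
#   encoder_hidden_states: [1, 128, 4096]
#   img_ids: [4096, 3]   txt_ids: [128, 3]   timestep: [1]
#   returned sample:       [1, 4096, 64]  (text tokens are stripped before norm_out/proj_out)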

View File

@@ -80,7 +80,10 @@ class GenericDiffusersLoader(ModelLoader):
"transformers",
"invokeai.backend.quantization.fast_quantized_transformers_model",
"invokeai.backend.quantization.fast_quantized_diffusion_model",
"transformer_bria",
]:
if module == "transformer_bria":
module = "invokeai.backend.bria.transformer_bria"
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines