Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-04-23 03:00:31 -04:00
added support for loading bria transformer
This commit is contained in:
0
invokeai/backend/bria/__init__.py
Normal file
314
invokeai/backend/bria/bria_utils.py
Normal file
@@ -0,0 +1,314 @@
import math
import os
from typing import List, Optional, Union

import numpy as np
import torch
import torch.distributed as dist
from diffusers.utils import logging
from transformers import (
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_t5_prompt_embeds(
    tokenizer: T5TokenizerFast,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str], None] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
):
    device = device or text_encoder.device

    if prompt is None:
        prompt = ""

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        # padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    # Concat zeros to max_sequence
    b, seq_len, dim = prompt_embeds.shape
    if seq_len < max_sequence_length:
        padding = torch.zeros(
            (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
        )
        prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)

    prompt_embeds = prompt_embeds.to(device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


# in order to get the same sigmas as in training and sample from them
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
    timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
    sigmas = timesteps / num_train_timesteps

    inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
    new_sigmas = sigmas[inds]
    return new_sigmas
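

# A minimal usage sketch of the schedule above (illustrative numbers only): 30
# inference steps subsample 30 evenly spaced sigmas out of the 1000 training values.
def _example_sigma_schedule() -> None:
    sigmas = get_original_sigmas(num_train_timesteps=1000, num_inference_steps=30)
    assert len(sigmas) == 30
    # The schedule descends from 1.0 toward 1 / num_train_timesteps.
    assert sigmas[0] == 1.0 and sigmas[-1] > 0.0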


def is_ng_none(negative_prompt):
    return (
        negative_prompt is None
        or negative_prompt == ""
        or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
        or (type(negative_prompt) == list and negative_prompt[0] == "")
    )


class CudaTimerContext:
    def __init__(self, times_arr):
        self.times_arr = times_arr

    def __enter__(self):
        self.before_event = torch.cuda.Event(enable_timing=True)
        self.after_event = torch.cuda.Event(enable_timing=True)
        self.before_event.record()

    def __exit__(self, type, value, traceback):
        self.after_event.record()
        torch.cuda.synchronize()
        elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000
        self.times_arr.append(elapsed_time)


def get_env_prefix():
    env = os.environ.get("CLOUD_PROVIDER", "AWS").upper()
    if env == "AWS":
        return "SM_CHANNEL"
    elif env == "AZURE":
        return "AZUREML_DATAREFERENCE"

    raise Exception(f"Env {env} not supported")


def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
    """Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
    return u


def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting
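

# A minimal sketch (illustrative values) of how the two helpers above can be combined
# during training: draw u, map it to discrete timesteps, and weight the per-sample
# loss. The scheme names and constants used here are assumptions, not the trainer's.
def _example_timestep_sampling(batch_size: int = 4, num_train_timesteps: int = 1000) -> torch.Tensor:
    u = compute_density_for_timestep_sampling(
        weighting_scheme="logit_normal", batch_size=batch_size, logit_mean=0.0, logit_std=1.0
    )
    timestep_indices = (u * num_train_timesteps).long().clamp(0, num_train_timesteps - 1)
    sigmas = torch.from_numpy(get_original_sigmas(num_train_timesteps, num_train_timesteps))[timestep_indices]
    return compute_loss_weighting_for_sd3(weighting_scheme="sigma_sqrt", sigmas=sigmas)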


def initialize_distributed():
    # Initialize the process group for distributed training
    dist.init_process_group("nccl")

    # Get the current process's rank (ID) and the total number of processes (world size)
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    print(f"Initialized distributed training: Rank {rank}/{world_size}")


def get_clip_prompt_embeds(
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 77,
    device: Optional[torch.device] = None,
):
    device = device or text_encoder.device
    assert max_sequence_length == tokenizer.model_max_length
    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Define tokenizers and text encoders
    tokenizers = [tokenizer, tokenizer_2]
    text_encoders = [text_encoder, text_encoder_2]

    # textual inversion: process multi-vector tokens if necessary
    prompt_embeds_list = []
    prompts = [prompt, prompt]
    for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)

        # We are always only interested in the pooled output of the final text encoder
        pooled_prompt_embeds = prompt_embeds[0]
        prompt_embeds = prompt_embeds.hidden_states[-2]

        prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )

    return prompt_embeds, pooled_prompt_embeds


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
    data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation. Defaults to 1.0.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            the dtype of the frequency tensor.
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore # [S, D/2]
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio, allegro
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
        return freqs_cis


class FluxPosEmbed(torch.nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        freqs_dtype = torch.float32 if is_mps else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
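

# A minimal sketch of the rotary-embedding path above: with axes_dim summing to the
# per-head dimension (16 + 56 + 56 = 128), an [S, 3] id tensor yields cos/sin tables
# of shape [S, 128]. The id values here are illustrative only.
def _example_rope_tables() -> None:
    pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
    ids = torch.zeros((4, 3))  # 4 positions, one column per axis
    freqs_cos, freqs_sin = pos_embed(ids)
    assert freqs_cos.shape == (4, 128) and freqs_sin.shape == (4, 128)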
459
invokeai/backend/bria/pipeline.py
Normal file
@@ -0,0 +1,459 @@
#!/usr/bin/env python
"""
Bria Text-to-Image Pipeline (GPU-ready)
Using local Bria checkpoints.
"""

from typing import Optional

import numpy as np
import torch
import torch.nn as nn

# bria_utils helpers
from .bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from PIL import Image
from tqdm import tqdm

# Custom Bria transformer
from .transformer_bria import BriaTransformer2DModel
from transformers import T5EncoderModel, T5TokenizerFast


# -----------------------------------------------------------------------------
# 1. Model Loader
# -----------------------------------------------------------------------------
class BriaModelLoader:
    def __init__(
        self,
        transformer_ckpt: str,
        vae_ckpt: str,
        text_encoder_ckpt: str,
        tokenizer_ckpt: str,
        device: torch.device,
    ):
        self.device = device

        # print("Loading Bria Transformer from", transformer_ckpt)
        # self.transformer = BriaTransformer2DModel.from_pretrained(transformer_ckpt, torch_dtype=torch.bfloat16).to(device)

        # print("Loading VAE from", vae_ckpt)
        # self.vae = AutoencoderKL.from_pretrained(vae_ckpt, torch_dtype=torch.float32).to(device)

        # print("Loading T5 Encoder from", text_encoder_ckpt)
        # self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_ckpt, torch_dtype=torch.float16).to(device)

        # print("Loading Tokenizer from", tokenizer_ckpt)
        # self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_ckpt, legacy=False)
        self.transformer = BriaTransformer2DModel.from_pretrained(transformer_ckpt, torch_dtype=torch.float16).to(
            device
        )
        self.vae = AutoencoderKL.from_pretrained(vae_ckpt, torch_dtype=torch.float16).to(device)
        self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_ckpt, torch_dtype=torch.float16).to(device)
        self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_ckpt)

    def get(self):
        return {
            "transformer": self.transformer,
            "vae": self.vae,
            "text_encoder": self.text_encoder,
            "tokenizer": self.tokenizer,
        }


# -----------------------------------------------------------------------------
# 2. Text Encoder (uses bria_utils)
# -----------------------------------------------------------------------------
class BriaTextEncoder:
    def __init__(
        self,
        text_encoder: T5EncoderModel,
        tokenizer: T5TokenizerFast,
        device: torch.device,
        max_length: int = 128,
    ):
        self.model = text_encoder.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length

    def encode(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_images_per_prompt: int = 1,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # 1) get positive embeddings
        pos = get_t5_prompt_embeds(
            tokenizer=self.tokenizer,
            text_encoder=self.model,
            prompt=prompt,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=self.max_length,
            device=self.device,
        )
        # 2) get negative or zeros
        if negative_prompt is None or is_ng_none(negative_prompt):
            neg = torch.zeros_like(pos)
        else:
            neg = get_t5_prompt_embeds(
                tokenizer=self.tokenizer,
                text_encoder=self.model,
                prompt=negative_prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=self.max_length,
                device=self.device,
            )

        # 3) build text_ids: shape [S_text, 3]
        # S_text = number of tokens = pos.shape[1]
        S_text = pos.shape[1]
        text_ids = torch.zeros((S_text, 3), device=self.device, dtype=torch.long)

        print(f"Text embeds shapes → pos: {pos.shape}, neg: {neg.shape}, text_ids: {text_ids.shape}")
        return pos, neg, text_ids


# -----------------------------------------------------------------------------
# 3. Latent Sampler
# -----------------------------------------------------------------------------
class BriaLatentSampler:
    def __init__(self, transformer: BriaTransformer2DModel, vae: AutoencoderKL, device: torch.device):
        self.device = device
        self.latent_channels = transformer.config.in_channels
        # self.latent_height = vae.config.sample_size
        # self.latent_width = vae.config.sample_size
        self.latent_height = 128
        self.latent_width = 128

    @staticmethod
    def _prepare_latent_image_ids(batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype):
        # Build the same img_ids FluxPipeline.prepare_latents would use
        latent_image_ids = torch.zeros((height, width, 3), device=device, dtype=dtype)
        latent_image_ids[..., 1] = torch.arange(height, device=device)[:, None]
        latent_image_ids[..., 2] = torch.arange(width, device=device)[None, :]
        # reshape to [1, height*width, 3] then repeat for batch
        latent_image_ids = latent_image_ids.view(1, height * width, 3)
        return latent_image_ids.repeat(batch_size, 1, 1)

    def sample(self, batch_size: int = 1, seed: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
        gen = torch.Generator(device=self.device).manual_seed(seed)

        # 1) sample & pack the noise
        shrunk = self.latent_channels // 4
        noise4d = torch.randn(
            (batch_size, shrunk, self.latent_height, self.latent_width),
            device=self.device,
            generator=gen,
        )
        latents = (
            noise4d.view(batch_size, shrunk, self.latent_height // 2, 2, self.latent_width // 2, 2)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(batch_size, (self.latent_height // 2) * (self.latent_width // 2), shrunk * 4)
        )

        # 2) build the matching latent_image_ids
        latent_image_ids = self._prepare_latent_image_ids(
            batch_size,
            self.latent_height // 2,
            self.latent_width // 2,
            device=self.device,
            dtype=torch.long,
        )
        if latent_image_ids.ndim == 3 and latent_image_ids.shape[0] == 1:
            latent_image_ids = latent_image_ids[0]  # [S_img, 3]

        latent_image_ids = latent_image_ids.squeeze(0)

        print(f"Sampled & packed latents: {latents.shape}")
        return latents, latent_image_ids
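

# A minimal sketch of the 2x2 "packing" used in sample() above: a (B, C, H, W) tensor
# becomes (B, (H/2)*(W/2), C*4) and can be unpacked back losslessly. Sizes here are
# illustrative, not the pipeline's working resolution.
def _example_pack_roundtrip() -> None:
    b, c, h, w = 1, 16, 8, 8
    x = torch.randn(b, c, h, w)
    packed = (
        x.view(b, c, h // 2, 2, w // 2, 2)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(b, (h // 2) * (w // 2), c * 4)
    )
    unpacked = (
        packed.view(b, h // 2, w // 2, c, 2, 2)
        .permute(0, 3, 1, 4, 2, 5)
        .reshape(b, c, h, w)
    )
    assert torch.equal(x, unpacked)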


# -----------------------------------------------------------------------------
# 4. Denoising Loop (uses bria_utils for σ schedule)
# -----------------------------------------------------------------------------
class BriaDenoise:
    def __init__(
        self,
        transformer: nn.Module,
        scheduler_name: str,
        device: torch.device,
        num_train_timesteps: int,
        num_inference_steps: int,
        **sched_kwargs,
    ):
        self.transformer = transformer.to(device)
        self.device = device

        # Build scheduler
        if scheduler_name == "flow_match":
            self.scheduler = FlowMatchEulerDiscreteScheduler.from_config(transformer.config, **sched_kwargs)
        else:
            from diffusers import DDIMScheduler

            self.scheduler = DDIMScheduler(**sched_kwargs)

        # Use the exact σ schedule from bria_utils (imported at module level)
        sigmas = get_original_sigmas(
            num_train_timesteps=num_train_timesteps,
            num_inference_steps=num_inference_steps,
        )
        self.scheduler.set_timesteps(
            num_inference_steps=None,
            timesteps=None,
            sigmas=sigmas,
            device=device,
        )

        # allow early exit
        self.interrupt = False
        # will be set in denoise()
        self._guidance_scale = 1.0
        self._joint_attention_kwargs = {}

    @property
    def guidance_scale(self) -> float:
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self) -> bool:
        return self.guidance_scale > 1.0

    @property
    def joint_attention_kwargs(self) -> dict:
        return self._joint_attention_kwargs

    @torch.no_grad()
    def denoise(
        self,
        latents: torch.Tensor,  # [B, seq_len, C_hidden]
        latent_image_ids: torch.Tensor,  # [B, seq_len, 3]
        prompt_embeds: torch.Tensor,  # [B, S_text, D]
        negative_prompt_embeds: torch.Tensor,  # [B, S_text, D]
        text_ids: torch.Tensor,  # [B, S_text, 3]
        num_inference_steps: int = 30,
        guidance_scale: float = 5.0,
        normalize: bool = False,
        clip_value: float | None = None,
        seed: int = 0,
    ) -> torch.Tensor:
        # 0) Quick cast & setup
        device = self.device
        # ensure dtype matches transformer
        target_dtype = next(self.transformer.parameters()).dtype
        latents = latents.to(device, dtype=target_dtype)
        prompt_embeds = prompt_embeds.to(device, dtype=target_dtype)
        negative_prompt_embeds = negative_prompt_embeds.to(device, dtype=target_dtype)

        # replicate reference encode_prompt behaviour
        if negative_prompt_embeds is None:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        if guidance_scale > 1.0:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        self._guidance_scale = guidance_scale

        # 1) Prepare Flow-Match timesteps identical to reference pipeline
        if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler) and getattr(
            self.scheduler.config, "use_dynamic_shifting", False
        ):
            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
            image_seq_len = latents.shape[1]
            mu = calculate_shift(image_seq_len, 256, 16_384, 0.25, 0.75)
            timesteps, num_inference_steps = retrieve_timesteps(
                self.scheduler, num_inference_steps, device, None, sigmas, mu=mu
            )
        else:
            sigmas = get_original_sigmas(
                num_train_timesteps=self.scheduler.config.num_train_timesteps,
                num_inference_steps=num_inference_steps,
            )
            timesteps, num_inference_steps = retrieve_timesteps(
                self.scheduler, num_inference_steps, device, None, sigmas
            )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 2) Loop with progress bar
        with tqdm(total=num_inference_steps, desc="Denoising", unit="step") as progress_bar:
            for i, t in enumerate(timesteps):
                # a) expand for CFG?
                latent_model_input = torch.cat([latents] * 2, dim=0) if self.do_classifier_free_guidance else latents

                # b) scale model input if needed
                if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # c) broadcast timestep
                timestep = t.expand(latent_model_input.shape[0])

                # d) predict noise
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                )[0]

                # e) classifier-free guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    cfg_noise_pred_text = noise_pred_text.std()
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # f) optional normalize/clip (normalization relies on the CFG text-branch std above)
                if normalize and self.do_classifier_free_guidance:
                    noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred

                if clip_value:
                    noise_pred = noise_pred.clamp(-clip_value, clip_value)

                # g) scheduler step, in-place
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if (i + 1) % 5 == 0 or i == len(timesteps) - 1:
                    progress_bar.update(5 if i + 1 < len(timesteps) else (len(timesteps) % 5))

                # # j) XLA sync
                # if XLA_AVAILABLE:
                #     xm.mark_step()

        # 3) Return the final packed latents (still [B, seq_len, C_hidden])
        return latents


# -----------------------------------------------------------------------------
# 5. Latents → Image
# -----------------------------------------------------------------------------
class BriaLatentsToImage:
    def __init__(self, vae: AutoencoderKL, device: torch.device):
        self.vae = vae.to(device)
        self.device = device

    @torch.no_grad()
    def decode(self, latents: torch.Tensor) -> list[Image.Image]:
        """
        Accepts either of the two packed shapes that come out of the denoiser:

        • [B, S, 16] – 3-D, where S = H² (e.g. 16,384 for 1024×1024)
        • [B, 1, S, 16] – 4-D with a spurious singleton dim

        Converts them to the VAE's expected shape [B, 4, H, W] before decoding.
        """
        # ---- 1. Un-pack to (B, 4, H, W) ----------------------------------
        if latents.ndim == 3:  # (B,S,16)
            B, S, C = latents.shape
            H2 = int(S**0.5)  # 128 for 1024×1024
            latents = (
                latents.view(B, H2, H2, 4, 2, 2)  # split channels into 4×(2×2)
                .permute(0, 3, 1, 4, 2, 5)  # (B,4,H2,2,W2,2)
                .reshape(B, 4, H2 * 2, H2 * 2)  # (B,4,H,W)
            )

        elif latents.ndim == 4 and latents.shape[1] == 1:  # (B,1,S,16)
            B, _, S, C = latents.shape
            H2 = int(S**0.5)
            latents = (
                latents.squeeze(1)  # -> (B,S,16)
                .view(B, H2, H2, 4, 2, 2)
                .permute(0, 3, 1, 4, 2, 5)
                .reshape(B, 4, H2 * 2, H2 * 2)
            )
        # else: already (B,4,H,W)

        # ---- 2. Standard VAE decode -----------------------------------------
        shift = 0 if self.vae.config.shift_factor is None else self.vae.config.shift_factor
        latents = (latents / self.vae.config.scaling_factor) + shift

        # temporarily move VAE to fp32 for the forward pass
        self.vae.to(dtype=torch.float32)
        images = self.vae.decode(latents.to(torch.float32)).sample  # full-precision decode
        self.vae.to(dtype=torch.bfloat16)  # move the VAE back to bf16 after the full-precision decode
        images = (images.clamp(-1, 1) + 1) / 2  # [0,1] fp32
        images = (images.cpu().permute(0, 2, 3, 1).numpy() * 255).astype("uint8")

        return [Image.fromarray(img) for img in images]


# -----------------------------------------------------------------------------
# Main: Assemble & Run
# -----------------------------------------------------------------------------
def main():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # --- Local checkpoint locations ---
    transformer_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/transformer"
    vae_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/vae"
    text_encoder_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/text_encoder"
    tokenizer_ckpt = "/home/ubuntu/invoke_local_nodes/bria_3_1/tokenizer"

    # 1. Load models
    loader = BriaModelLoader(
        transformer_ckpt,
        vae_ckpt,
        text_encoder_ckpt,
        tokenizer_ckpt,
        device,
    )
    mdl = loader.get()
    # if diffusers.__version__ >= "0.27.0":
    #     mdl["transformer"].enable_xformers_memory_efficient_attention()  # now safe
    # else:
    #     mdl["transformer"].disable_xformers_memory_efficient_attention()  # keep quality

    # 2. Encode prompt (capture text_ids as well)
    text_enc = BriaTextEncoder(mdl["text_encoder"], mdl["tokenizer"], device)
    pos_embeds, neg_embeds, text_ids = text_enc.encode(
        prompt="3d rendered image, landscape made out of ice cream, rich ice cream textures, ice cream-valley , with a milky ice cream river, the ice cream has rich texture with visible chocolate chunks and intricate details, in the background an air balloon floats over the vally, in the sky visible dramatic like clouds, brown-chocolate color white and pink pallet, drama, beautiful surreal landscape, polarizing lens, very high contrast, 3d rendered realistic",
        negative_prompt=None,
        num_images_per_prompt=1,
    )

    # 3. Sample initial noise → get both latents & latent_image_ids
    sampler = BriaLatentSampler(mdl["transformer"], mdl["vae"], device)
    init_latents, latent_image_ids = sampler.sample(batch_size=1, seed=1249141701)

    # 4. Denoise, passing latent_image_ids and text_ids
    denoiser = BriaDenoise(
        transformer=mdl["transformer"],
        scheduler_name="flow_match",
        device=device,
        num_train_timesteps=1000,
        num_inference_steps=30,
        base_shift=0.5,
        max_shift=1.15,
    )
    final_latents = denoiser.denoise(
        init_latents,
        latent_image_ids,
        pos_embeds,
        neg_embeds,
        text_ids,
        num_inference_steps=30,
        guidance_scale=5.0,
        seed=1249141701,
    )

    # 5. Decode
    decoder = BriaLatentsToImage(mdl["vae"], device)
    images = decoder.decode(final_latents)
    for i, img in enumerate(images):
        img.save(f"bria_output_{i}.png")


if __name__ == "__main__":
    main()
320
invokeai/backend/bria/transformer_bria.py
Normal file
@@ -0,0 +1,320 @@
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from .bria_utils import FluxPosEmbed as EmbedND
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Timesteps(nn.Module):
    def __init__(
        self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
    ):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale
        self.time_theta = time_theta

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
            max_period=self.time_theta,
        )
        return t_emb


class TimestepProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, time_theta):
        super().__init__()

        self.time_proj = Timesteps(
            num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
        )
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep, dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype))  # (N, D)
        return timesteps_emb


"""
Based on FluxPipeline with several changes:
- No pooled embeddings
- Zero padding is used for prompts
- No guidance embedding, since this is not a distilled variant
"""


class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    The Transformer model introduced in Flux.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Parameters:
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
        num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: Optional[int] = None,
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
        rope_theta=10000,
        time_theta=10000,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)

        self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)

        # if pooled_projection_dim:
        #     self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu")

        if guidance_embeds:
            self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)

        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype)
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype)
        else:
            guidance = None

        # temb = (
        #     self.time_text_embed(timestep, pooled_projections)
        #     if guidance is None
        #     else self.time_text_embed(timestep, guidance, pooled_projections)
        # )

        temb = self.time_embed(timestep, dtype=hidden_states.dtype)

        # if pooled_projections:
        #     temb += self.pooled_text_embed(pooled_projections)

        if guidance:
            temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)

        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if len(txt_ids.shape) == 2:
            ids = torch.cat((txt_ids, img_ids), dim=0)
        else:
            ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pos_embed(ids)

        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

            # controlnet residual
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

            # controlnet residual
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                    + controlnet_single_block_samples[index_block // interval_control]
                )

        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
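

# A minimal construction sketch with deliberately tiny, illustrative dimensions (the
# shipped checkpoints use the defaults above). Note that axes_dims_rope must sum to
# attention_head_dim so the rotary tables match the per-head width.
def _example_tiny_bria_transformer() -> BriaTransformer2DModel:
    return BriaTransformer2DModel(
        patch_size=1,
        in_channels=16,
        num_layers=1,
        num_single_layers=1,
        attention_head_dim=32,
        num_attention_heads=2,
        joint_attention_dim=64,
        axes_dims_rope=[8, 12, 12],
    )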
@@ -80,7 +80,10 @@ class GenericDiffusersLoader(ModelLoader):
            "transformers",
            "invokeai.backend.quantization.fast_quantized_transformers_model",
            "invokeai.backend.quantization.fast_quantized_diffusion_model",
            "transformer_bria",
        ]:
            if module == "transformer_bria":
                module = "invokeai.backend.bria.transformer_bria"
            res_type = sys.modules[module]
        else:
            res_type = sys.modules["diffusers"].pipelines
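Roughly, the loader change above means a model whose config names the module "transformer_bria" is resolved from the invokeai.backend.bria package rather than from diffusers. A hedged sketch of the equivalent lookup (the use of importlib here is an assumption; the loader itself reads from sys.modules):

import importlib

module = importlib.import_module("invokeai.backend.bria.transformer_bria")
transformer_cls = getattr(module, "BriaTransformer2DModel")  # resolve the class by name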