Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-21 10:08:10 -05:00)

Compare commits: controlnet ... ryan/flux- (4 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 437d1087a2 | |
| | a76a1244bc | |
| | debc84cf95 | |
| | 4000dd1843 | |
```diff
@@ -22,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
     title="FLUX Text Encoding",
     tags=["prompt", "conditioning", "flux"],
     category="conditioning",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class FluxTextEncoderInvocation(BaseInvocation):
```
```diff
@@ -41,6 +41,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
     t5_max_seq_len: Literal[256, 512] = InputField(
         description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
     )
+    use_short_t5_seq_len: bool = InputField(
+        description="Use a shorter sequence length for the T5 encoder if a short prompt is used. This can improve "
+        + "performance and reduce peak memory, but may result in slightly different image outputs.",
+        default=True,
+    )
     prompt: str = InputField(
         description="Text prompt to encode.",
         ui_component=UIComponent.Textarea,
```
```diff
@@ -65,6 +70,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
 
         prompt = [self.prompt]
 
+        valid_seq_lens = [self.t5_max_seq_len]
+        if self.use_short_t5_seq_len:
+            valid_seq_lens = [128, 256, 512]
+
         with (
             t5_text_encoder_info as t5_text_encoder,
             t5_tokenizer_info as t5_tokenizer,
```
```diff
@@ -72,10 +81,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
             assert isinstance(t5_text_encoder, T5EncoderModel)
             assert isinstance(t5_tokenizer, T5Tokenizer)
 
-            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
+            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False)
 
             context.util.signal_progress("Running T5 encoder")
-            prompt_embeds = t5_encoder(prompt)
+            prompt_embeds = t5_encoder(prompt, valid_seq_lens)
 
         assert isinstance(prompt_embeds, torch.Tensor)
         return prompt_embeds
```
```diff
@@ -113,10 +122,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
                 # There are currently no supported CLIP quantized models. Add support here if needed.
                 raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
 
-            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
+            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True)
 
             context.util.signal_progress("Running CLIP encoder")
-            pooled_prompt_embeds = clip_encoder(prompt)
+            pooled_prompt_embeds = clip_encoder(prompt, [77])
 
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
         return pooled_prompt_embeds
```
```diff
@@ -16,6 +16,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
 
 def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     assert dim % 2 == 0
+    # scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
     scale = (
         torch.arange(0, dim, 2, dtype=torch.float32 if pos.device.type == "mps" else torch.float64, device=pos.device)
         / dim
```
```diff
@@ -24,12 +25,12 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     out = torch.einsum("...n,d->...nd", pos, omega)
     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
-    return out.float()
+    return out.to(dtype=pos.dtype, device=pos.device)
 
 
 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_ = xq.view(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.view(*xk.shape[:-1], -1, 1, 2)
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+    return xq_out.view(*xq.shape), xk_out.view(*xk.shape)
```
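The `apply_rope` change drops the `.float()` upcast and the `reshape`/`type_as` round trip, so the rotation now runs in the tensors' own dtype and memory layout. The sketch below is illustrative only and not part of this diff; `old_apply_rope`, `new_apply_rope`, and the tensor shapes are assumptions made for the demo. On float32 inputs the two formulations should agree exactly; on half-precision inputs the new path computes in the lower precision, which is where small numeric differences can come from.

```python
import torch
from torch import Tensor


def old_apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # Previous formulation: upcast to float32, reshape, then cast back.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


def new_apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # New formulation: stay in the input dtype and use view() instead of reshape().
    xq_ = xq.view(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.view(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.view(*xq.shape), xk_out.view(*xk.shape)


if __name__ == "__main__":
    b, heads, seq, head_dim = 1, 4, 8, 64
    xq = torch.randn(b, heads, seq, head_dim)
    xk = torch.randn(b, heads, seq, head_dim)
    # Random stand-in for the positional embedding: shaped like rope()'s
    # (..., seq, head_dim/2, 2, 2) output so it broadcasts against (..., seq, head_dim/2, 1, 2).
    freqs_cis = torch.randn(b, 1, seq, head_dim // 2, 2, 2)
    q_old, k_old = old_apply_rope(xq, xk, freqs_cis)
    q_new, k_new = new_apply_rope(xq, xk, freqs_cis)
    print(torch.allclose(q_old, q_new), torch.allclose(k_old, k_new))  # expected: True True
```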
```diff
@@ -1,32 +1,43 @@
 # Initially pulled from https://github.com/black-forest-labs/flux
 
 
 from torch import Tensor, nn
 from transformers import PreTrainedModel, PreTrainedTokenizer
 
 
 class HFEncoder(nn.Module):
-    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
+    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool):
         super().__init__()
-        self.max_length = max_length
         self.is_clip = is_clip
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
         self.tokenizer = tokenizer
         self.hf_module = encoder
         self.hf_module = self.hf_module.eval().requires_grad_(False)
 
-    def forward(self, text: list[str]) -> Tensor:
+    def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor:
+        valid_seq_lens = sorted(valid_seq_lens)
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
-            max_length=self.max_length,
-            return_length=False,
+            max_length=max(valid_seq_lens),
+            return_length=True,
             return_overflowing_tokens=False,
             padding="max_length",
             return_tensors="pt",
         )
 
+        seq_len: int = batch_encoding["length"][0].item()
+        # Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens.
+        selected_seq_len = valid_seq_lens[-1]
+        for len in valid_seq_lens:
+            if len >= seq_len:
+                selected_seq_len = len
+                break
+
+        input_ids = batch_encoding["input_ids"][..., :selected_seq_len]
+
         outputs = self.hf_module(
-            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+            input_ids=input_ids.to(self.hf_module.device),
             attention_mask=None,
             output_hidden_states=False,
         )
```
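For reference, the bucket-selection step that `forward` now performs can be isolated into a few lines. This is an illustrative sketch rather than code from the diff (`select_seq_len` is a hypothetical helper name): given the tokenized prompt length reported by `return_length=True`, it picks the smallest entry of `valid_seq_lens` that still fits, falling back to the largest entry when the prompt exceeds all of them (the tokenizer has already truncated to `max(valid_seq_lens)` by that point).

```python
def select_seq_len(seq_len: int, valid_seq_lens: list[int]) -> int:
    """Pick the smallest valid sequence length that can hold `seq_len` tokens."""
    valid_seq_lens = sorted(valid_seq_lens)
    selected = valid_seq_lens[-1]  # fall back to the largest bucket
    for candidate in valid_seq_lens:
        if candidate >= seq_len:
            selected = candidate
            break
    return selected


assert select_seq_len(9, [128, 256, 512]) == 128    # short prompt -> shortest bucket
assert select_seq_len(300, [128, 256, 512]) == 512  # needs more than 256 -> next bucket that fits
assert select_seq_len(600, [128, 256, 512]) == 512  # over-long prompt was already truncated to 512
```

Padding only to the selected bucket, instead of always to the configured maximum, is what trims T5 compute and peak memory for short prompts.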
```diff
@@ -66,10 +66,7 @@ class RMSNorm(torch.nn.Module):
         self.scale = nn.Parameter(torch.ones(dim))
 
     def forward(self, x: Tensor):
-        x_dtype = x.dtype
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms).to(dtype=x_dtype) * self.scale
+        return torch.nn.functional.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
 
 
 class QKNorm(torch.nn.Module):
```
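Here the hand-written RMSNorm (upcast to float32, scale by the reciprocal RMS, cast back) is swapped for the fused `torch.nn.functional.rms_norm` op available in recent PyTorch releases (2.4+). A minimal standalone check of the equivalence for float32 inputs, not part of the diff and assuming such a PyTorch version is installed:

```python
import torch


def manual_rms_norm(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # The formulation removed by this diff.
    x_dtype = x.dtype
    x = x.float()
    rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
    return (x * rrms).to(dtype=x_dtype) * scale


if __name__ == "__main__":
    dim = 128
    x = torch.randn(2, 16, dim)
    scale = torch.ones(dim)  # stand-in for RMSNorm's learned self.scale parameter
    fused = torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)
    print(torch.allclose(manual_rms_norm(x, scale), fused, atol=1e-5))  # expected: True
```

For half-precision inputs the two versions may differ slightly, depending on the intermediate precision used by the fused kernel.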