mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Compare commits
4 Commits
controlnet
...
ryan/flux-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
437d1087a2 | ||
|
|
a76a1244bc | ||
|
|
debc84cf95 | ||
|
|
4000dd1843 |
@@ -22,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
|
|||||||
title="FLUX Text Encoding",
|
title="FLUX Text Encoding",
|
||||||
tags=["prompt", "conditioning", "flux"],
|
tags=["prompt", "conditioning", "flux"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
version="1.1.0",
|
version="1.2.0",
|
||||||
classification=Classification.Prototype,
|
classification=Classification.Prototype,
|
||||||
)
|
)
|
||||||
class FluxTextEncoderInvocation(BaseInvocation):
|
class FluxTextEncoderInvocation(BaseInvocation):
|
||||||
@@ -41,6 +41,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
|||||||
t5_max_seq_len: Literal[256, 512] = InputField(
|
t5_max_seq_len: Literal[256, 512] = InputField(
|
||||||
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
|
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
|
||||||
)
|
)
|
||||||
|
use_short_t5_seq_len: bool = InputField(
|
||||||
|
description="Use a shorter sequence length for the T5 encoder if a short prompt is used. This can improve "
|
||||||
|
+ "performance and reduced peak memory, but may result in slightly different image outputs.",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
prompt: str = InputField(
|
prompt: str = InputField(
|
||||||
description="Text prompt to encode.",
|
description="Text prompt to encode.",
|
||||||
ui_component=UIComponent.Textarea,
|
ui_component=UIComponent.Textarea,
|
||||||
@@ -65,6 +70,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
prompt = [self.prompt]
|
prompt = [self.prompt]
|
||||||
|
|
||||||
|
valid_seq_lens = [self.t5_max_seq_len]
|
||||||
|
if self.use_short_t5_seq_len:
|
||||||
|
valid_seq_lens = [128, 256, 512]
|
||||||
|
|
||||||
with (
|
with (
|
||||||
t5_text_encoder_info as t5_text_encoder,
|
t5_text_encoder_info as t5_text_encoder,
|
||||||
t5_tokenizer_info as t5_tokenizer,
|
t5_tokenizer_info as t5_tokenizer,
|
||||||
@@ -72,10 +81,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
|||||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||||
assert isinstance(t5_tokenizer, T5Tokenizer)
|
assert isinstance(t5_tokenizer, T5Tokenizer)
|
||||||
|
|
||||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False)
|
||||||
|
|
||||||
context.util.signal_progress("Running T5 encoder")
|
context.util.signal_progress("Running T5 encoder")
|
||||||
prompt_embeds = t5_encoder(prompt)
|
prompt_embeds = t5_encoder(prompt, valid_seq_lens)
|
||||||
|
|
||||||
assert isinstance(prompt_embeds, torch.Tensor)
|
assert isinstance(prompt_embeds, torch.Tensor)
|
||||||
return prompt_embeds
|
return prompt_embeds
|
||||||
@@ -113,10 +122,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
|||||||
# There are currently no supported CLIP quantized models. Add support here if needed.
|
# There are currently no supported CLIP quantized models. Add support here if needed.
|
||||||
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
|
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
|
||||||
|
|
||||||
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
|
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True)
|
||||||
|
|
||||||
context.util.signal_progress("Running CLIP encoder")
|
context.util.signal_progress("Running CLIP encoder")
|
||||||
pooled_prompt_embeds = clip_encoder(prompt)
|
pooled_prompt_embeds = clip_encoder(prompt, [77])
|
||||||
|
|
||||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||||
return pooled_prompt_embeds
|
return pooled_prompt_embeds
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
|||||||
|
|
||||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||||
assert dim % 2 == 0
|
assert dim % 2 == 0
|
||||||
|
# scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
|
||||||
scale = (
|
scale = (
|
||||||
torch.arange(0, dim, 2, dtype=torch.float32 if pos.device.type == "mps" else torch.float64, device=pos.device)
|
torch.arange(0, dim, 2, dtype=torch.float32 if pos.device.type == "mps" else torch.float64, device=pos.device)
|
||||||
/ dim
|
/ dim
|
||||||
@@ -24,12 +25,12 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
|||||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||||
return out.float()
|
return out.to(dtype=pos.dtype, device=pos.device)
|
||||||
|
|
||||||
|
|
||||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
||||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
xq_ = xq.view(*xq.shape[:-1], -1, 1, 2)
|
||||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
xk_ = xk.view(*xk.shape[:-1], -1, 1, 2)
|
||||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
return xq_out.view(*xq.shape), xk_out.view(*xk.shape)
|
||||||
|
|||||||
@@ -1,32 +1,43 @@
|
|||||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||||
|
|
||||||
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
class HFEncoder(nn.Module):
|
class HFEncoder(nn.Module):
|
||||||
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
|
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_length = max_length
|
|
||||||
self.is_clip = is_clip
|
self.is_clip = is_clip
|
||||||
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
|
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.hf_module = encoder
|
self.hf_module = encoder
|
||||||
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
||||||
|
|
||||||
def forward(self, text: list[str]) -> Tensor:
|
def forward(self, text: list[str], valid_seq_lens: list[int]) -> Tensor:
|
||||||
|
valid_seq_lens = sorted(valid_seq_lens)
|
||||||
batch_encoding = self.tokenizer(
|
batch_encoding = self.tokenizer(
|
||||||
text,
|
text,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=self.max_length,
|
max_length=max(valid_seq_lens),
|
||||||
return_length=False,
|
return_length=True,
|
||||||
return_overflowing_tokens=False,
|
return_overflowing_tokens=False,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
seq_len: int = batch_encoding["length"][0].item()
|
||||||
|
# Find selected_seq_len, the minimum valid sequence length that can contain all of the input tokens.
|
||||||
|
selected_seq_len = valid_seq_lens[-1]
|
||||||
|
for len in valid_seq_lens:
|
||||||
|
if len >= seq_len:
|
||||||
|
selected_seq_len = len
|
||||||
|
break
|
||||||
|
|
||||||
|
input_ids = batch_encoding["input_ids"][..., :selected_seq_len]
|
||||||
|
|
||||||
outputs = self.hf_module(
|
outputs = self.hf_module(
|
||||||
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
input_ids=input_ids.to(self.hf_module.device),
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -66,10 +66,7 @@ class RMSNorm(torch.nn.Module):
|
|||||||
self.scale = nn.Parameter(torch.ones(dim))
|
self.scale = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
def forward(self, x: Tensor):
|
||||||
x_dtype = x.dtype
|
return torch.nn.functional.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
|
||||||
x = x.float()
|
|
||||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
|
||||||
return (x * rrms).to(dtype=x_dtype) * self.scale
|
|
||||||
|
|
||||||
|
|
||||||
class QKNorm(torch.nn.Module):
|
class QKNorm(torch.nn.Module):
|
||||||
|
|||||||
Reference in New Issue
Block a user