mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 09:18:00 -05:00
Compare commits
23 Commits
v5.0.0.dev
...
ryan/flux-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fb5db32bb0 | ||
|
|
823c663e1b | ||
|
|
d40c9ff60a | ||
|
|
373b46867a | ||
|
|
dc66952491 | ||
|
|
1b80832b22 | ||
|
|
96b0450b20 | ||
|
|
45792cc152 | ||
|
|
f0baf880b5 | ||
|
|
a8a2fc106d | ||
|
|
d23ad1818d | ||
|
|
4181ab654b | ||
|
|
1c97360f9f | ||
|
|
74d6fceeb6 | ||
|
|
766ddc18dc | ||
|
|
e6ff7488a1 | ||
|
|
89a652cfcd | ||
|
|
b227b9059d | ||
|
|
3599a4a3e4 | ||
|
|
5dd619e137 | ||
|
|
7d447cbb88 | ||
|
|
3bbba7e4b1 | ||
|
|
b1845019fe |
135
invokeai/app/invocations/flux_text_encoder.py
Normal file
135
invokeai/app/invocations/flux_text_encoder.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||
from optimum.quanto import qfloat8
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import InputField
|
||||
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_encoder",
|
||||
title="FLUX Text Encoding",
|
||||
tags=["image"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
||||
use_8bit: bool = InputField(
|
||||
default=False, description="Whether to quantize the transformer model to 8-bit precision."
|
||||
)
|
||||
positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
|
||||
|
||||
# TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
|
||||
# compatible with other ConditioningOutputs.
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
||||
|
||||
t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Determine the T5 max sequence length based on the model.
|
||||
if self.model == "flux-schnell":
|
||||
max_seq_len = 256
|
||||
# elif self.model == "flux-dev":
|
||||
# max_seq_len = 512
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {self.model}")
|
||||
|
||||
# Load the CLIP tokenizer.
|
||||
clip_tokenizer_path = flux_model_dir / "tokenizer"
|
||||
clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
|
||||
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||
|
||||
# Load the T5 tokenizer.
|
||||
t5_tokenizer_path = flux_model_dir / "tokenizer_2"
|
||||
t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
|
||||
assert isinstance(t5_tokenizer, T5TokenizerFast)
|
||||
|
||||
clip_text_encoder_path = flux_model_dir / "text_encoder"
|
||||
t5_text_encoder_path = flux_model_dir / "text_encoder_2"
|
||||
with (
|
||||
context.models.load_local_model(
|
||||
model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
|
||||
) as clip_text_encoder,
|
||||
context.models.load_local_model(
|
||||
model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
|
||||
) as t5_text_encoder,
|
||||
):
|
||||
assert isinstance(clip_text_encoder, CLIPTextModel)
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
pipeline = FluxPipeline(
|
||||
scheduler=None,
|
||||
vae=None,
|
||||
text_encoder=clip_text_encoder,
|
||||
tokenizer=clip_tokenizer,
|
||||
text_encoder_2=t5_text_encoder,
|
||||
tokenizer_2=t5_tokenizer,
|
||||
transformer=None,
|
||||
)
|
||||
|
||||
# prompt_embeds: T5 embeddings
|
||||
# pooled_prompt_embeds: CLIP embeddings
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
|
||||
prompt=self.positive_prompt,
|
||||
prompt_2=self.positive_prompt,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
max_sequence_length=max_seq_len,
|
||||
)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@staticmethod
|
||||
def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
|
||||
model = CLIPTextModel.from_pretrained(path, local_files_only=True)
|
||||
assert isinstance(model, CLIPTextModel)
|
||||
return model
|
||||
|
||||
def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
|
||||
if self.use_8bit:
|
||||
model_8bit_path = path / "quantized"
|
||||
if model_8bit_path.exists():
|
||||
# The quantized model exists, load it.
|
||||
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
|
||||
# something that we should be able to make much faster.
|
||||
q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
|
||||
|
||||
# Access the underlying wrapped model.
|
||||
# We access the wrapped model, even though it is private, because it simplifies the type checking by
|
||||
# always returning a T5EncoderModel from this function.
|
||||
model = q_model._wrapped
|
||||
else:
|
||||
# The quantized model does not exist yet, quantize and save it.
|
||||
# TODO(ryand): dtype?
|
||||
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
|
||||
assert isinstance(model, T5EncoderModel)
|
||||
|
||||
q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
|
||||
|
||||
model_8bit_path.mkdir(parents=True, exist_ok=True)
|
||||
q_model.save_pretrained(model_8bit_path)
|
||||
|
||||
# (See earlier comment about accessing the wrapped model.)
|
||||
model = q_model._wrapped
|
||||
else:
|
||||
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
|
||||
|
||||
assert isinstance(model, T5EncoderModel)
|
||||
return model
|
||||
255
invokeai/app/invocations/flux_text_to_image.py
Normal file
255
invokeai/app/invocations/flux_text_to_image.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from einops import rearrange, repeat
|
||||
from flux.model import Flux
|
||||
from flux.modules.autoencoder import AutoEncoder
|
||||
from flux.sampling import denoise, get_noise, get_schedule, unpack
|
||||
from flux.util import configs as flux_configs
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from transformers.models.auto import AutoModelForTextEncoding
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
||||
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
TFluxModelKeys = Literal["flux-schnell"]
|
||||
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
|
||||
|
||||
|
||||
class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
|
||||
base_class = FluxTransformer2DModel
|
||||
|
||||
|
||||
class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
|
||||
auto_class = AutoModelForTextEncoding
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_to_image",
|
||||
title="FLUX Text to Image",
|
||||
tags=["image"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Text-to-image generation using a FLUX model."""
|
||||
|
||||
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
||||
quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
|
||||
default="raw", description="The type of quantization to use for the transformer model."
|
||||
)
|
||||
use_8bit: bool = InputField(
|
||||
default=False, description="Whether to quantize the transformer model to 8-bit precision."
|
||||
)
|
||||
positive_text_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(default=4, description="Number of diffusion steps.")
|
||||
guidance: float = InputField(
|
||||
default=4.0,
|
||||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images.",
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
||||
flux_transformer_path = context.models.download_and_cache_model(
|
||||
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
|
||||
)
|
||||
flux_ae_path = context.models.download_and_cache_model(
|
||||
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
|
||||
)
|
||||
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
|
||||
latents = self._run_diffusion(
|
||||
context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
|
||||
)
|
||||
image = self._run_vae_decoding(context, flux_ae_path, latents)
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
flux_transformer_path: Path,
|
||||
clip_embeddings: torch.Tensor,
|
||||
t5_embeddings: torch.Tensor,
|
||||
):
|
||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
# Prepare input noise.
|
||||
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
|
||||
# CPU RNG?
|
||||
x = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
img, img_ids = self._prepare_latent_img_patches(x)
|
||||
|
||||
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
||||
is_schnell = "schnell" in str(flux_transformer_path)
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=img.shape[1],
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
||||
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
||||
# if the cache is not empty.
|
||||
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||
|
||||
with context.models.load_local_model(
|
||||
model_path=flux_transformer_path, loader=self._load_flux_transformer
|
||||
) as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
guidance=self.guidance,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
|
||||
return x
|
||||
|
||||
def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert an input image in latent space to patches for diffusion.
|
||||
|
||||
This implementation was extracted from:
|
||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
||||
"""
|
||||
bs, c, h, w = latent_img.shape
|
||||
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
img_ids = img_ids.to(latent_img.device)
|
||||
|
||||
return img, img_ids
|
||||
|
||||
def _run_vae_decoding(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
flux_ae_path: Path,
|
||||
latents: torch.Tensor,
|
||||
) -> Image.Image:
|
||||
with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
# TODO(ryand): Test that this works with both float16 and bfloat16.
|
||||
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
|
||||
img = vae.decode(latents)
|
||||
|
||||
img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c")
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
return img_pil
|
||||
|
||||
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
|
||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||
if self.quantization_type == "raw":
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
params = flux_configs["flux-schnell"].params
|
||||
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params).to(inference_dtype)
|
||||
|
||||
state_dict = load_file(path)
|
||||
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
elif self.quantization_type == "NF4":
|
||||
model_path = path.parent / "bnb_nf4.safetensors"
|
||||
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
params = flux_configs["flux-schnell"].params
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params)
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
|
||||
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
|
||||
# this on GPUs without bfloat16 support.
|
||||
state_dict = load_file(model_path)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
elif self.quantization_type == "llm_int8":
|
||||
raise NotImplementedError("LLM int8 quantization is not yet supported.")
|
||||
# model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||
# with accelerate.init_empty_weights():
|
||||
# empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
# assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
# model_int8_path = path / "bnb_llm_int8"
|
||||
# assert model_int8_path.exists()
|
||||
# with accelerate.init_empty_weights():
|
||||
# model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
|
||||
|
||||
# sd = load_file(model_int8_path / "model.safetensors")
|
||||
# model.load_state_dict(sd, strict=True, assign=True)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
|
||||
|
||||
assert isinstance(model, Flux)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _load_flux_vae(path: Path) -> AutoEncoder:
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
ae_params = flux_configs["flux-schnell"].ae_params
|
||||
with accelerate.init_empty_weights():
|
||||
ae = AutoEncoder(ae_params)
|
||||
|
||||
state_dict = load_file(path)
|
||||
ae.load_state_dict(state_dict, strict=True, assign=True)
|
||||
return ae
|
||||
517
invokeai/backend/bnb.py
Normal file
517
invokeai/backend/bnb.py
Normal file
@@ -0,0 +1,517 @@
|
||||
from typing import Any, Optional, Set, Type
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# The utils in this file take ideas from
|
||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||
|
||||
|
||||
# Patterns:
|
||||
# - Quantize:
|
||||
# - Initialize model on meta device
|
||||
# - Replace layers
|
||||
# - Load state_dict to cpu
|
||||
# - Load state_dict into model
|
||||
# - Quantize on GPU
|
||||
# - Extract state_dict
|
||||
# - Save
|
||||
|
||||
# - Load:
|
||||
# - Initialize model on meta device
|
||||
# - Replace layers
|
||||
# - Load state_dict to cpu
|
||||
# - Load state_dict into model on cpu
|
||||
# - Move to GPU
|
||||
|
||||
|
||||
# class InvokeInt8Params(bnb.nn.Int8Params):
|
||||
# """Overrides `bnb.nn.Int8Params` to add the following functionality:
|
||||
# - Make it possible to load a quantized state dict without putting the weight on a "cuda" device.
|
||||
# """
|
||||
|
||||
# def quantize(self, device: Optional[torch.device] = None):
|
||||
# device = device or torch.device("cuda")
|
||||
# if device.type != "cuda":
|
||||
# raise RuntimeError(f"Int8Params quantization is only supported on CUDA devices ({device=}).")
|
||||
|
||||
# # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302
|
||||
# B = self.data.contiguous().half().cuda(device)
|
||||
# if self.has_fp16_weights:
|
||||
# self.data = B
|
||||
# else:
|
||||
# # we store the 8-bit rows-major weight
|
||||
# # we convert this weight to the turning/ampere weight during the first inference pass
|
||||
# CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
||||
# del CBt
|
||||
# del SCBt
|
||||
# self.data = CB
|
||||
# self.CB = CB
|
||||
# self.SCB = SCB
|
||||
|
||||
|
||||
class Invoke2Linear8bitLt(torch.nn.Linear):
|
||||
"""This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_features: int,
|
||||
output_features: int,
|
||||
bias=True,
|
||||
has_fp16_weights=True,
|
||||
memory_efficient_backward=False,
|
||||
threshold=0.0,
|
||||
index=None,
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
Initialize Linear8bitLt class.
|
||||
|
||||
Args:
|
||||
input_features (`int`):
|
||||
Number of input features of the linear layer.
|
||||
output_features (`int`):
|
||||
Number of output features of the linear layer.
|
||||
bias (`bool`, defaults to `True`):
|
||||
Whether the linear class uses the bias term as well.
|
||||
"""
|
||||
super().__init__(input_features, output_features, bias, device)
|
||||
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
|
||||
self.state = bnb.MatmulLtState()
|
||||
self.index = index
|
||||
|
||||
self.state.threshold = threshold
|
||||
self.state.has_fp16_weights = has_fp16_weights
|
||||
self.state.memory_efficient_backward = memory_efficient_backward
|
||||
if threshold > 0.0 and not has_fp16_weights:
|
||||
self.state.use_pool = True
|
||||
|
||||
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
|
||||
self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
|
||||
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
|
||||
scb_name = "SCB"
|
||||
|
||||
# case 1: .cuda was called, SCB is in self.weight
|
||||
param_from_weight = getattr(self.weight, scb_name)
|
||||
# case 2: self.init_8bit_state was called, SCB is in self.state
|
||||
param_from_state = getattr(self.state, scb_name)
|
||||
# case 3: SCB is in self.state, weight layout reordered after first forward()
|
||||
layout_reordered = self.state.CxB is not None
|
||||
|
||||
key_name = prefix + f"{scb_name}"
|
||||
format_name = prefix + "weight_format"
|
||||
|
||||
if not self.state.has_fp16_weights:
|
||||
if param_from_weight is not None:
|
||||
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
|
||||
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
|
||||
elif param_from_state is not None and not layout_reordered:
|
||||
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
|
||||
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
|
||||
elif param_from_state is not None:
|
||||
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
|
||||
weights_format = self.state.formatB
|
||||
# At this point `weights_format` is an str
|
||||
if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
|
||||
raise ValueError(f"Unrecognized weights format {weights_format}")
|
||||
|
||||
weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]
|
||||
|
||||
destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
super()._load_from_state_dict(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
unexpected_copy = list(unexpected_keys)
|
||||
|
||||
for key in unexpected_copy:
|
||||
input_name = key[len(prefix) :]
|
||||
if input_name == "SCB":
|
||||
if self.weight.SCB is None:
|
||||
# buffers not yet initialized, can't access them directly without quantizing first
|
||||
raise RuntimeError(
|
||||
"Loading a quantized checkpoint into non-quantized Linear8bitLt is "
|
||||
"not supported. Please call module.cuda() before module.load_state_dict()",
|
||||
)
|
||||
|
||||
input_param = state_dict[key]
|
||||
self.weight.SCB.copy_(input_param)
|
||||
|
||||
if self.state.SCB is not None:
|
||||
self.state.SCB = self.weight.SCB
|
||||
|
||||
unexpected_keys.remove(key)
|
||||
|
||||
def init_8bit_state(self):
|
||||
self.state.CB = self.weight.CB
|
||||
self.state.SCB = self.weight.SCB
|
||||
self.weight.CB = None
|
||||
self.weight.SCB = None
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
self.state.is_training = self.training
|
||||
if self.weight.CB is not None:
|
||||
self.init_8bit_state()
|
||||
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||
|
||||
if not self.state.has_fp16_weights:
|
||||
if self.state.CB is not None and self.state.CxB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del self.state.CB
|
||||
self.weight.data = self.state.CxB
|
||||
return out
|
||||
|
||||
|
||||
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
"""Wraps `bnb.nn.Linear8bitLt` and adds the following functionality:
|
||||
- enables instantiation directly on the device
|
||||
- re-quantizaton when loading the state dict
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *args: Any, device: Optional[torch.device] = None, threshold: float = 6.0, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(*args, device=device, threshold=threshold, **kwargs)
|
||||
# If the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
|
||||
# filling the device memory with float32 weights which could lead to OOM
|
||||
# if torch.tensor(0, device=device).device.type == "cuda":
|
||||
# self.quantize_()
|
||||
# self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
|
||||
# self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
super()._load_from_state_dict(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
unexpected_copy = list(unexpected_keys)
|
||||
|
||||
for key in unexpected_copy:
|
||||
input_name = key[len(prefix) :]
|
||||
if input_name == "SCB":
|
||||
if self.weight.SCB is None:
|
||||
# buffers not yet initialized, can't access them directly without quantizing first
|
||||
raise RuntimeError(
|
||||
"Loading a quantized checkpoint into non-quantized Linear8bitLt is "
|
||||
"not supported. Please call module.cuda() before module.load_state_dict()",
|
||||
)
|
||||
|
||||
input_param = state_dict[key]
|
||||
self.weight.SCB.copy_(input_param)
|
||||
|
||||
if self.state.SCB is not None:
|
||||
self.state.SCB = self.weight.SCB
|
||||
|
||||
unexpected_keys.remove(key)
|
||||
|
||||
def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
|
||||
"""Inplace quantize."""
|
||||
if weight is None:
|
||||
weight = self.weight.data
|
||||
if weight.data.dtype == torch.int8:
|
||||
# already quantized
|
||||
return
|
||||
assert isinstance(self.weight, bnb.nn.Int8Params)
|
||||
self.weight = self.quantize(self.weight, weight, device)
|
||||
|
||||
@staticmethod
|
||||
def quantize(
|
||||
int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device]
|
||||
) -> bnb.nn.Int8Params:
|
||||
device = device or torch.device("cuda")
|
||||
if device.type != "cuda":
|
||||
raise RuntimeError(f"Unexpected device type: {device.type}")
|
||||
# https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302
|
||||
B = weight.contiguous().to(device=device, dtype=torch.float16)
|
||||
if int8params.has_fp16_weights:
|
||||
int8params.data = B
|
||||
else:
|
||||
CB, CBt, SCB, SCBt, _ = bnb.functional.double_quant(B)
|
||||
del CBt
|
||||
del SCBt
|
||||
int8params.data = CB
|
||||
int8params.CB = CB
|
||||
int8params.SCB = SCB
|
||||
return int8params
|
||||
|
||||
|
||||
# class _Linear4bit(bnb.nn.Linear4bit):
|
||||
# """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantizaton when loading the
|
||||
# state dict, meta-device initialization, and materialization."""
|
||||
|
||||
# def __init__(self, *args: Any, device: Optional[torch.device] = None, **kwargs: Any) -> None:
|
||||
# super().__init__(*args, device=device, **kwargs)
|
||||
# self.weight = cast(bnb.nn.Params4bit, self.weight) # type: ignore[has-type]
|
||||
# self.bias = cast(Optional[torch.nn.Parameter], self.bias) # type: ignore[has-type]
|
||||
# # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
|
||||
# # filling the device memory with float32 weights which could lead to OOM
|
||||
# if torch.tensor(0, device=device).device.type == "cuda":
|
||||
# self.quantize_()
|
||||
# self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
|
||||
# self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
|
||||
|
||||
# def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
|
||||
# """Inplace quantize."""
|
||||
# if weight is None:
|
||||
# weight = self.weight.data
|
||||
# if weight.data.dtype == torch.uint8:
|
||||
# # already quantized
|
||||
# return
|
||||
# assert isinstance(self.weight, bnb.nn.Params4bit)
|
||||
# self.weight = self.quantize(self.weight, weight, device)
|
||||
|
||||
# @staticmethod
|
||||
# def quantize(
|
||||
# params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device]
|
||||
# ) -> bnb.nn.Params4bit:
|
||||
# device = device or torch.device("cuda")
|
||||
# if device.type != "cuda":
|
||||
# raise RuntimeError(f"Unexpected device type: {device.type}")
|
||||
# # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L156-L159
|
||||
# w = weight.contiguous().to(device=device, dtype=torch.half)
|
||||
# w_4bit, quant_state = bnb.functional.quantize_4bit(
|
||||
# w,
|
||||
# blocksize=params4bit.blocksize,
|
||||
# compress_statistics=params4bit.compress_statistics,
|
||||
# quant_type=params4bit.quant_type,
|
||||
# )
|
||||
# return _replace_param(params4bit, w_4bit, quant_state)
|
||||
|
||||
# def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
|
||||
# if self.weight.dtype == torch.uint8: # was quantized
|
||||
# # cannot init the quantized params directly
|
||||
# weight = torch.empty(self.weight.quant_state.shape, device=device, dtype=torch.half)
|
||||
# else:
|
||||
# weight = torch.empty_like(self.weight.data, device=device)
|
||||
# device = torch.device(device)
|
||||
# if device.type == "cuda": # re-quantize
|
||||
# self.quantize_(weight, device)
|
||||
# else:
|
||||
# self.weight = _replace_param(self.weight, weight)
|
||||
# if self.bias is not None:
|
||||
# self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
|
||||
# return self
|
||||
|
||||
|
||||
def convert_model_to_bnb_llm_int8(model: torch.nn.Module, ignore_modules: set[str]):
|
||||
linear_cls = InvokeLinear8bitLt
|
||||
_convert_linear_layers(model, linear_cls, ignore_modules)
|
||||
|
||||
# TODO(ryand): Is this necessary?
|
||||
# set the compute dtype if necessary
|
||||
# for m in model.modules():
|
||||
# if isinstance(m, bnb.nn.Linear4bit):
|
||||
# m.compute_dtype = self.dtype
|
||||
# m.compute_type_is_set = False
|
||||
|
||||
|
||||
# class BitsandbytesPrecision(Precision):
|
||||
# """Plugin for quantizing weights with `bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__.
|
||||
|
||||
# .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
|
||||
|
||||
# .. note::
|
||||
# The optimizer is not automatically replaced with ``bitsandbytes.optim.Adam8bit`` or equivalent 8-bit optimizers.
|
||||
|
||||
# Args:
|
||||
# mode: The quantization mode to use.
|
||||
# dtype: The compute dtype to use.
|
||||
# ignore_modules: The submodules whose Linear layers should not be replaced, for example. ``{"lm_head"}``.
|
||||
# This might be desirable for numerical stability. The string will be checked in as a prefix, so a value like
|
||||
# "transformer.blocks" will ignore all linear layers in all of the transformer blocks.
|
||||
# """
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"],
|
||||
# dtype: Optional[torch.dtype] = None,
|
||||
# ignore_modules: Optional[Set[str]] = None,
|
||||
# ) -> None:
|
||||
# if dtype is None:
|
||||
# # try to be smart about the default selection
|
||||
# if mode.startswith("int8"):
|
||||
# dtype = torch.float16
|
||||
# else:
|
||||
# dtype = (
|
||||
# torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
||||
# )
|
||||
# if mode.startswith("int8") and dtype is not torch.float16:
|
||||
# # this limitation is mentioned in https://huggingface.co/blog/hf-bitsandbytes-integration#usage
|
||||
# raise ValueError(f"{mode!r} only works with `dtype=torch.float16`, but you chose `{dtype}`")
|
||||
|
||||
# globals_ = globals()
|
||||
# mode_to_cls = {
|
||||
# "nf4": globals_["_NF4Linear"],
|
||||
# "nf4-dq": globals_["_NF4DQLinear"],
|
||||
# "fp4": globals_["_FP4Linear"],
|
||||
# "fp4-dq": globals_["_FP4DQLinear"],
|
||||
# "int8-training": globals_["_Linear8bitLt"],
|
||||
# "int8": globals_["_Int8LinearInference"],
|
||||
# }
|
||||
# self._linear_cls = mode_to_cls[mode]
|
||||
# self.dtype = dtype
|
||||
# self.ignore_modules = ignore_modules or set()
|
||||
|
||||
# @override
|
||||
# def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
# # avoid naive users thinking they quantized their model
|
||||
# if not any(isinstance(m, torch.nn.Linear) for m in module.modules()):
|
||||
# raise TypeError(
|
||||
# "You are using the bitsandbytes precision plugin, but your model has no Linear layers. This plugin"
|
||||
# " won't work for your model."
|
||||
# )
|
||||
|
||||
# # convert modules if they haven't been converted already
|
||||
# if not any(isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)) for m in module.modules()):
|
||||
# # this will not quantize the model but only replace the layer classes
|
||||
# _convert_layers(module, self._linear_cls, self.ignore_modules)
|
||||
|
||||
# # set the compute dtype if necessary
|
||||
# for m in module.modules():
|
||||
# if isinstance(m, bnb.nn.Linear4bit):
|
||||
# m.compute_dtype = self.dtype
|
||||
# m.compute_type_is_set = False
|
||||
# return module
|
||||
|
||||
|
||||
# def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_dict: OrderedDict, *_: Any) -> None:
|
||||
# # There is only one key that ends with `*.weight`, the other one is the bias
|
||||
# weight_key = next((name for name in state_dict if name.endswith("weight")), None)
|
||||
# if weight_key is None:
|
||||
# return
|
||||
# # Load the weight from the state dict and re-quantize it
|
||||
# weight = state_dict.pop(weight_key)
|
||||
# quantize_fn(weight)
|
||||
|
||||
|
||||
# def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _IncompatibleKeys) -> None:
|
||||
# # since we manually loaded the weight in the `_quantize_on_load_hook` hook, we need to avoid this missing key false
|
||||
# # positive
|
||||
# for key in reversed(incompatible_keys.missing_keys):
|
||||
# if key.endswith("weight"):
|
||||
# incompatible_keys.missing_keys.remove(key)
|
||||
|
||||
|
||||
def _convert_linear_layers(
|
||||
module: torch.nn.Module, linear_cls: Type, ignore_modules: Set[str], prefix: str = ""
|
||||
) -> None:
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
# since we are going to copy over the child's data, the device doesn't matter. I chose CPU
|
||||
# to avoid spiking CUDA memory even though initialization is slower
|
||||
# 4bit layers support quantizing from meta-device params so this is only relevant for 8-bit
|
||||
_Linear4bit = globals()["_Linear4bit"]
|
||||
device = torch.device("meta" if issubclass(linear_cls, _Linear4bit) else "cpu")
|
||||
replacement = linear_cls(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
device=device,
|
||||
)
|
||||
if has_bias:
|
||||
replacement.bias = _replace_param(replacement.bias, child.bias.data.clone())
|
||||
state = {"quant_state": replacement.weight.quant_state if issubclass(linear_cls, _Linear4bit) else None}
|
||||
replacement.weight = _replace_param(replacement.weight, child.weight.data.clone(), **state)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers(child, linear_cls, ignore_modules, prefix=fullname)
|
||||
|
||||
|
||||
# def _replace_linear_layers(
|
||||
# model: torch.nn.Module,
|
||||
# linear_layer_type: Literal["Linear8bitLt", "Linear4bit"],
|
||||
# modules_to_not_convert: set[str],
|
||||
# current_key_name: str | None = None,
|
||||
# ):
|
||||
# has_been_replaced = False
|
||||
# for name, module in model.named_children():
|
||||
# if current_key_name is None:
|
||||
# current_key_name = []
|
||||
# current_key_name.append(name)
|
||||
# if isinstance(module, torch.nn.Linear) and name not in modules_to_not_convert:
|
||||
# # Check if the current key is not in the `modules_to_not_convert`
|
||||
# current_key_name_str = ".".join(current_key_name)
|
||||
# proceed = True
|
||||
# for key in modules_to_not_convert:
|
||||
# if (
|
||||
# (key in current_key_name_str) and (key + "." in current_key_name_str)
|
||||
# ) or key == current_key_name_str:
|
||||
# proceed = False
|
||||
# break
|
||||
# if proceed:
|
||||
# # Load bnb module with empty weight and replace ``nn.Linear` module
|
||||
# if bnb_quantization_config.load_in_8bit:
|
||||
# bnb_module = bnb.nn.Linear8bitLt(
|
||||
# module.in_features,
|
||||
# module.out_features,
|
||||
# module.bias is not None,
|
||||
# has_fp16_weights=False,
|
||||
# threshold=bnb_quantization_config.llm_int8_threshold,
|
||||
# )
|
||||
# elif bnb_quantization_config.load_in_4bit:
|
||||
# bnb_module = bnb.nn.Linear4bit(
|
||||
# module.in_features,
|
||||
# module.out_features,
|
||||
# module.bias is not None,
|
||||
# bnb_quantization_config.bnb_4bit_compute_dtype,
|
||||
# compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
|
||||
# quant_type=bnb_quantization_config.bnb_4bit_quant_type,
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError("load_in_8bit and load_in_4bit can't be both False")
|
||||
# bnb_module.weight.data = module.weight.data
|
||||
# if module.bias is not None:
|
||||
# bnb_module.bias.data = module.bias.data
|
||||
# bnb_module.requires_grad_(False)
|
||||
# setattr(model, name, bnb_module)
|
||||
# has_been_replaced = True
|
||||
# if len(list(module.children())) > 0:
|
||||
# _, _has_been_replaced = _replace_with_bnb_layers(
|
||||
# module, bnb_quantization_config, modules_to_not_convert, current_key_name
|
||||
# )
|
||||
# has_been_replaced = has_been_replaced | _has_been_replaced
|
||||
# # Remove the last key for recursion
|
||||
# current_key_name.pop(-1)
|
||||
# return model, has_been_replaced
|
||||
129
invokeai/backend/load_flux_model.py
Normal file
129
invokeai/backend/load_flux_model.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from diffusers.models.model_loading_utils import load_state_dict
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from diffusers.utils import (
|
||||
CONFIG_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
_get_checkpoint_shard_files,
|
||||
is_accelerate_available,
|
||||
)
|
||||
from optimum.quanto import qfloat8
|
||||
from optimum.quanto.models import QuantizedDiffusersModel
|
||||
from optimum.quanto.models.shared_dict import ShardedStateDict
|
||||
|
||||
from invokeai.backend.requantize import requantize
|
||||
|
||||
|
||||
class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
|
||||
base_class = FluxTransformer2DModel
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
|
||||
if cls.base_class is None:
|
||||
raise ValueError("The `base_class` attribute needs to be configured.")
|
||||
|
||||
if not is_accelerate_available():
|
||||
raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
if os.path.isdir(model_name_or_path):
|
||||
# Look for a quantization map
|
||||
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
|
||||
if not os.path.exists(qmap_path):
|
||||
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
|
||||
|
||||
# Look for original model config file.
|
||||
model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
|
||||
if not os.path.exists(model_config_path):
|
||||
raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")
|
||||
|
||||
with open(qmap_path, "r", encoding="utf-8") as f:
|
||||
qmap = json.load(f)
|
||||
|
||||
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||
original_model_cls_name = json.load(f)["_class_name"]
|
||||
configured_cls_name = cls.base_class.__name__
|
||||
if configured_cls_name != original_model_cls_name:
|
||||
raise ValueError(
|
||||
f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
|
||||
)
|
||||
|
||||
# Create an empty model
|
||||
config = cls.base_class.load_config(model_name_or_path)
|
||||
with init_empty_weights():
|
||||
model = cls.base_class.from_config(config)
|
||||
|
||||
# Look for the index of a sharded checkpoint
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||
if os.path.exists(checkpoint_file):
|
||||
# Convert the checkpoint path to a list of shards
|
||||
_, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
|
||||
# Create a mapping for the sharded safetensor files
|
||||
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
|
||||
else:
|
||||
# Look for a single checkpoint file
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
|
||||
if not os.path.exists(checkpoint_file):
|
||||
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
|
||||
# Get state_dict from model checkpoint
|
||||
state_dict = load_state_dict(checkpoint_file)
|
||||
|
||||
# Requantize and load quantized weights from state_dict
|
||||
requantize(model, state_dict=state_dict, quantization_map=qmap)
|
||||
model.eval()
|
||||
return cls(model)
|
||||
else:
|
||||
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
||||
|
||||
|
||||
def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
|
||||
# model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
||||
model_8bit_path = path / "quantized"
|
||||
if model_8bit_path.exists():
|
||||
# The quantized model exists, load it.
|
||||
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
|
||||
# something that we should be able to make much faster.
|
||||
q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
|
||||
|
||||
# Access the underlying wrapped model.
|
||||
# We access the wrapped model, even though it is private, because it simplifies the type checking by
|
||||
# always returning a FluxTransformer2DModel from this function.
|
||||
model = q_model._wrapped
|
||||
else:
|
||||
# The quantized model does not exist yet, quantize and save it.
|
||||
# TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
|
||||
# GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
|
||||
# here.
|
||||
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
|
||||
q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
|
||||
|
||||
model_8bit_path.mkdir(parents=True, exist_ok=True)
|
||||
q_model.save_pretrained(model_8bit_path)
|
||||
|
||||
# (See earlier comment about accessing the wrapped model.)
|
||||
model = q_model._wrapped
|
||||
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
start = time.time()
|
||||
model = load_flux_transformer(
|
||||
Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/")
|
||||
)
|
||||
print(f"Time to load: {time.time() - start}s")
|
||||
print("hi")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
124
invokeai/backend/load_flux_model_bnb_llm_int8_old.py
Normal file
124
invokeai/backend/load_flux_model_bnb_llm_int8_old.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
|
||||
from accelerate.utils.bnb import get_keys_to_not_convert
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.backend.bnb import quantize_model_llm_int8
|
||||
|
||||
# Docs:
|
||||
# https://huggingface.co/docs/accelerate/usage_guides/quantization
|
||||
# https://huggingface.co/docs/bitsandbytes/v0.43.3/en/integrations#accelerate
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
return next(parameter.parameters()).device
|
||||
|
||||
|
||||
# def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], llm_int8_threshold: int = 6):
|
||||
# """Apply bitsandbytes LLM.8bit() quantization to the model."""
|
||||
# model_device = get_parameter_device(model)
|
||||
# if model_device.type != "meta":
|
||||
# # Note: This is not strictly required, but I can't think of a good reason to quantize a model that's not on the
|
||||
# # meta device, so we enforce it for now.
|
||||
# raise RuntimeError("The model should be on the meta device to apply LLM.8bit() quantization.")
|
||||
|
||||
# bnb_quantization_config = BnbQuantizationConfig(
|
||||
# load_in_8bit=True,
|
||||
# llm_int8_threshold=llm_int8_threshold,
|
||||
# )
|
||||
|
||||
# with accelerate.init_empty_weights():
|
||||
# model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
# return model
|
||||
|
||||
|
||||
def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
|
||||
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||
with accelerate.init_empty_weights():
|
||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
|
||||
bnb_quantization_config = BnbQuantizationConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_threshold=6,
|
||||
)
|
||||
|
||||
model_8bit_path = path / "bnb_llm_int8"
|
||||
if model_8bit_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
# Note that the model loading code is the same when loading from quantized vs original weights. The only
|
||||
# difference is the weights_location.
|
||||
# model = load_and_quantize_model(
|
||||
# empty_model,
|
||||
# weights_location=model_8bit_path,
|
||||
# bnb_quantization_config=bnb_quantization_config,
|
||||
# # device_map="auto",
|
||||
# device_map={"": "cpu"},
|
||||
# )
|
||||
|
||||
# TODO: Handle the keys that were not quantized (get_keys_to_not_convert).
|
||||
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
|
||||
|
||||
# model = quantize_model_llm_int8(empty_model, set())
|
||||
|
||||
# Load sharded state dict.
|
||||
files = list(path.glob("*.safetensors"))
|
||||
state_dict = dict()
|
||||
for file in files:
|
||||
sd = load_file(file)
|
||||
state_dict.update(sd)
|
||||
|
||||
else:
|
||||
# The quantized model does not exist yet, quantize and save it.
|
||||
model = load_and_quantize_model(
|
||||
empty_model,
|
||||
weights_location=path,
|
||||
bnb_quantization_config=bnb_quantization_config,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
keys_to_not_convert = get_keys_to_not_convert(empty_model) # TODO
|
||||
|
||||
model_8bit_path.mkdir(parents=True, exist_ok=True)
|
||||
accl = accelerate.Accelerator()
|
||||
accl.save_model(model, model_8bit_path)
|
||||
|
||||
# ---------------------
|
||||
|
||||
# model = quantize_model_llm_int8(empty_model, set())
|
||||
|
||||
# # Load sharded state dict.
|
||||
# files = list(path.glob("*.safetensors"))
|
||||
# state_dict = dict()
|
||||
# for file in files:
|
||||
# sd = load_file(file)
|
||||
# state_dict.update(sd)
|
||||
|
||||
# # Load the state dict into the model. The bitsandbytes layers know how to load from both quantized and
|
||||
# # non-quantized state dicts.
|
||||
# result = model.load_state_dict(state_dict, strict=True)
|
||||
# model = model.to("cuda")
|
||||
|
||||
# ---------------------
|
||||
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
start = time.time()
|
||||
model = load_flux_transformer(
|
||||
Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/")
|
||||
)
|
||||
print(f"Time to load: {time.time() - start}s")
|
||||
print("hi")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -54,6 +54,7 @@ def filter_files(
|
||||
"lora_weights.safetensors",
|
||||
"weights.pb",
|
||||
"onnx_data",
|
||||
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
|
||||
)
|
||||
):
|
||||
paths.append(file)
|
||||
@@ -62,7 +63,7 @@ def filter_files(
|
||||
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
||||
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
||||
# will adhere to this naming convention, so this is an area to be careful of.
|
||||
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
paths.append(file)
|
||||
|
||||
# limit search to subfolder if requested
|
||||
@@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
if variant == ModelRepoVariant.Flax:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
# Note: '.model' was added to support:
|
||||
# https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
|
||||
elif path.suffix in [".json", ".txt", ".model"]:
|
||||
result.add(path)
|
||||
|
||||
elif variant in [
|
||||
@@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
continue
|
||||
|
||||
for candidate_list in subfolder_weights.values():
|
||||
# Check if at least one of the files has the explicit fp16 variant.
|
||||
at_least_one_fp16 = False
|
||||
for candidate in candidate_list:
|
||||
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
|
||||
at_least_one_fp16 = True
|
||||
break
|
||||
|
||||
if not at_least_one_fp16:
|
||||
# If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
|
||||
# candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
|
||||
# we'll simply keep all the candidates. An example of a model that hits this case is
|
||||
# `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
|
||||
for candidate in candidate_list:
|
||||
result.add(candidate.path)
|
||||
|
||||
# The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
|
||||
# candidate.
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
result.add(highest_score_candidate.path)
|
||||
|
||||
125
invokeai/backend/quantization/bnb_llm_int8.py
Normal file
125
invokeai/backend/quantization/bnb_llm_int8.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
|
||||
# The utils in this file are partially inspired by:
|
||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||
|
||||
|
||||
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
|
||||
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
|
||||
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
||||
|
||||
|
||||
class InvokeInt8Params(bnb.nn.Int8Params):
|
||||
"""We override cuda() to avoid re-quantizing the weights in the following cases:
|
||||
- We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
|
||||
- We are moving the model back-and-forth between the cpu and gpu.
|
||||
"""
|
||||
|
||||
def cuda(self, device):
|
||||
if self.has_fp16_weights:
|
||||
return super().cuda(device)
|
||||
elif self.CB is not None and self.SCB is not None:
|
||||
self.data = self.data.cuda()
|
||||
self.CB = self.CB.cuda()
|
||||
self.SCB = self.SCB.cuda()
|
||||
else:
|
||||
# we store the 8-bit rows-major weight
|
||||
# we convert this weight to the turning/ampere weight during the first inference pass
|
||||
B = self.data.contiguous().half().cuda(device)
|
||||
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
||||
del CBt
|
||||
del SCBt
|
||||
self.data = CB
|
||||
self.CB = CB
|
||||
self.SCB = SCB
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
prefix: str,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
weight = state_dict.pop(prefix + "weight")
|
||||
bias = state_dict.pop(prefix + "bias", None)
|
||||
|
||||
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
||||
scb = state_dict.pop(prefix + "SCB", None)
|
||||
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
|
||||
_weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
assert len(state_dict) == 0
|
||||
|
||||
if scb is not None:
|
||||
# We are loading a pre-quantized state dict.
|
||||
self.weight = InvokeInt8Params(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
has_fp16_weights=False,
|
||||
# Note: After quantization, CB is the same as weight.
|
||||
CB=weight,
|
||||
SCB=scb,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
else:
|
||||
# We are loading a non-quantized state dict.
|
||||
|
||||
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
|
||||
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
|
||||
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
|
||||
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
|
||||
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
|
||||
self.weight = InvokeInt8Params(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
has_fp16_weights=False,
|
||||
CB=None,
|
||||
SCB=None,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
|
||||
def _convert_linear_layers_to_llm_8bit(
|
||||
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
||||
) -> None:
|
||||
"""Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = InvokeLinear8bitLt(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
has_fp16_weights=False,
|
||||
threshold=outlier_threshold,
|
||||
)
|
||||
replacement.weight.data = child.weight.data
|
||||
if has_bias:
|
||||
replacement.bias.data = child.bias.data
|
||||
replacement.requires_grad_(False)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers_to_llm_8bit(
|
||||
child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
|
||||
)
|
||||
|
||||
|
||||
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
|
||||
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
|
||||
_convert_linear_layers_to_llm_8bit(
|
||||
module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
|
||||
)
|
||||
|
||||
return model
|
||||
156
invokeai/backend/quantization/bnb_nf4.py
Normal file
156
invokeai/backend/quantization/bnb_nf4.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# This file contains utils for working with models that use bitsandbytes NF4 quantization.
|
||||
# The utils in this file are partially inspired by:
|
||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||
|
||||
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
|
||||
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
|
||||
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
||||
|
||||
|
||||
class InvokeLinearNF4(bnb.nn.LinearNF4):
|
||||
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:
|
||||
- Ability to load Linear NF4 layers from a pre-quantized state_dict.
|
||||
- Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device.
|
||||
"""
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
prefix: str,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
"""This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
|
||||
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
|
||||
"""
|
||||
weight = state_dict.pop(prefix + "weight")
|
||||
bias = state_dict.pop(prefix + "bias", None)
|
||||
# We expect the remaining keys to be quant_state keys.
|
||||
quant_state_sd = state_dict
|
||||
|
||||
# During serialization, the quant_state is stored as subkeys of "weight." (See
|
||||
# `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix.
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
|
||||
|
||||
if len(quant_state_sd) > 0:
|
||||
# We are loading a pre-quantized state dict.
|
||||
self.weight = bnb.nn.Params4bit.from_prequantized(
|
||||
data=weight, quantized_stats=quant_state_sd, device=weight.device
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
|
||||
else:
|
||||
# We are loading a non-quantized state dict.
|
||||
|
||||
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
|
||||
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
|
||||
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
|
||||
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
|
||||
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
|
||||
self.weight = bnb.nn.Params4bit(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
compress_statistics=self.weight.compress_statistics,
|
||||
quant_type=self.weight.quant_type,
|
||||
quant_storage=self.weight.quant_storage,
|
||||
module=self,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
|
||||
def _replace_param(
|
||||
param: torch.nn.Parameter | bnb.nn.Params4bit,
|
||||
data: torch.Tensor,
|
||||
) -> torch.nn.Parameter:
|
||||
"""A helper function to replace the data of a model parameter with new data in a way that allows replacing params on
|
||||
the "meta" device.
|
||||
|
||||
Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters.
|
||||
"""
|
||||
if param.device.type == "meta":
|
||||
# Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to
|
||||
# re-create the param instead of overwriting the data.
|
||||
if isinstance(param, bnb.nn.Params4bit):
|
||||
return bnb.nn.Params4bit(
|
||||
data,
|
||||
requires_grad=data.requires_grad,
|
||||
quant_state=param.quant_state,
|
||||
compress_statistics=param.compress_statistics,
|
||||
quant_type=param.quant_type,
|
||||
)
|
||||
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
|
||||
|
||||
param.data = data
|
||||
return param
|
||||
|
||||
|
||||
def _convert_linear_layers_to_nf4(
|
||||
module: torch.nn.Module,
|
||||
ignore_modules: set[str],
|
||||
compute_dtype: torch.dtype,
|
||||
compress_statistics: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""Convert all linear layers in the model to NF4 quantized linear layers.
|
||||
|
||||
Args:
|
||||
module: All linear layers in this module will be converted.
|
||||
ignore_modules: A set of module prefixes to ignore when converting linear layers.
|
||||
compute_dtype: The dtype to use for computation in the quantized linear layers.
|
||||
compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization
|
||||
constants from the first quantization are quantized again.
|
||||
prefix: The prefix of the current module in the model. Used to call this function recursively.
|
||||
"""
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = InvokeLinearNF4(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
compute_dtype=torch.float16,
|
||||
compress_statistics=compress_statistics,
|
||||
)
|
||||
if has_bias:
|
||||
replacement.bias = _replace_param(replacement.bias, child.bias.data)
|
||||
replacement.weight = _replace_param(replacement.weight, child.weight.data)
|
||||
replacement.requires_grad_(False)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname)
|
||||
|
||||
|
||||
def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype):
|
||||
"""Apply bitsandbytes nf4 quantization to the model.
|
||||
|
||||
You likely want to call this function inside a `accelerate.init_empty_weights()` context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
# Initialize the model from a config on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = ModelClass.from_config(...)
|
||||
|
||||
# Add NF4 quantization linear layers to the model - still on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16)
|
||||
|
||||
# Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
# Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes
|
||||
# place.
|
||||
model.to("cuda")
|
||||
```
|
||||
"""
|
||||
_convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype)
|
||||
|
||||
return model
|
||||
@@ -0,0 +1,77 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from diffusers.models.model_loading_utils import load_state_dict
|
||||
from diffusers.utils import (
|
||||
CONFIG_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
_get_checkpoint_shard_files,
|
||||
is_accelerate_available,
|
||||
)
|
||||
from optimum.quanto.models import QuantizedDiffusersModel
|
||||
from optimum.quanto.models.shared_dict import ShardedStateDict
|
||||
|
||||
from invokeai.backend.requantize import requantize
|
||||
|
||||
|
||||
class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
|
||||
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
|
||||
if cls.base_class is None:
|
||||
raise ValueError("The `base_class` attribute needs to be configured.")
|
||||
|
||||
if not is_accelerate_available():
|
||||
raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
if os.path.isdir(model_name_or_path):
|
||||
# Look for a quantization map
|
||||
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
|
||||
if not os.path.exists(qmap_path):
|
||||
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
|
||||
|
||||
# Look for original model config file.
|
||||
model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
|
||||
if not os.path.exists(model_config_path):
|
||||
raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")
|
||||
|
||||
with open(qmap_path, "r", encoding="utf-8") as f:
|
||||
qmap = json.load(f)
|
||||
|
||||
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||
original_model_cls_name = json.load(f)["_class_name"]
|
||||
configured_cls_name = cls.base_class.__name__
|
||||
if configured_cls_name != original_model_cls_name:
|
||||
raise ValueError(
|
||||
f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
|
||||
)
|
||||
|
||||
# Create an empty model
|
||||
config = cls.base_class.load_config(model_name_or_path)
|
||||
with init_empty_weights():
|
||||
model = cls.base_class.from_config(config)
|
||||
|
||||
# Look for the index of a sharded checkpoint
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||
if os.path.exists(checkpoint_file):
|
||||
# Convert the checkpoint path to a list of shards
|
||||
_, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
|
||||
# Create a mapping for the sharded safetensor files
|
||||
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
|
||||
else:
|
||||
# Look for a single checkpoint file
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
|
||||
if not os.path.exists(checkpoint_file):
|
||||
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
|
||||
# Get state_dict from model checkpoint
|
||||
state_dict = load_state_dict(checkpoint_file)
|
||||
|
||||
# Requantize and load quantized weights from state_dict
|
||||
requantize(model, state_dict=state_dict, quantization_map=qmap)
|
||||
model.eval()
|
||||
return cls(model)
|
||||
else:
|
||||
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
||||
@@ -0,0 +1,61 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from optimum.quanto.models import QuantizedTransformersModel
|
||||
from optimum.quanto.models.shared_dict import ShardedStateDict
|
||||
from transformers import AutoConfig
|
||||
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
|
||||
|
||||
from invokeai.backend.requantize import requantize
|
||||
|
||||
|
||||
class FastQuantizedTransformersModel(QuantizedTransformersModel):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
|
||||
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
|
||||
if cls.auto_class is None:
|
||||
raise ValueError(
|
||||
"Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
|
||||
)
|
||||
if not is_accelerate_available():
|
||||
raise ValueError("Reloading a quantized transformers model requires the accelerate library.")
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
if os.path.isdir(model_name_or_path):
|
||||
# Look for a quantization map
|
||||
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
|
||||
if not os.path.exists(qmap_path):
|
||||
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
|
||||
with open(qmap_path, "r", encoding="utf-8") as f:
|
||||
qmap = json.load(f)
|
||||
# Create an empty model
|
||||
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||
with init_empty_weights():
|
||||
model = cls.auto_class.from_config(config)
|
||||
# Look for the index of a sharded checkpoint
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||
if os.path.exists(checkpoint_file):
|
||||
# Convert the checkpoint path to a list of shards
|
||||
checkpoint_file, sharded_metadata = get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
|
||||
# Create a mapping for the sharded safetensor files
|
||||
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
|
||||
else:
|
||||
# Look for a single checkpoint file
|
||||
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)
|
||||
if not os.path.exists(checkpoint_file):
|
||||
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
|
||||
# Get state_dict from model checkpoint
|
||||
state_dict = load_state_dict(checkpoint_file)
|
||||
# Requantize and load quantized weights from state_dict
|
||||
requantize(model, state_dict=state_dict, quantization_map=qmap)
|
||||
if getattr(model.config, "tie_word_embeddings", True):
|
||||
# Tie output weight embeddings to input weight embeddings
|
||||
# Note that if they were quantized they would NOT be tied
|
||||
model.tie_weights()
|
||||
# Set model in evaluation mode as it is done in transformers
|
||||
model.eval()
|
||||
return cls(model)
|
||||
else:
|
||||
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
||||
@@ -0,0 +1,89 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_time(name: str):
|
||||
"""Helper context manager to log the time taken by a block of code."""
|
||||
start = time.time()
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
end = time.time()
|
||||
print(f"'{name}' took {end - start:.4f} secs")
|
||||
|
||||
|
||||
def main():
|
||||
# Load the FLUX transformer model onto the meta device.
|
||||
model_path = Path(
|
||||
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
|
||||
)
|
||||
|
||||
with log_time("Initialize FLUX transformer on meta device"):
|
||||
model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
|
||||
with accelerate.init_empty_weights():
|
||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
modules_to_not_convert: set[str] = set()
|
||||
|
||||
model_int8_path = model_path / "bnb_llm_int8"
|
||||
if model_int8_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
|
||||
|
||||
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
|
||||
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
sd = load_file(model_int8_path / "model.safetensors")
|
||||
model.load_state_dict(sd, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda"):
|
||||
model = model.to("cuda")
|
||||
|
||||
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
|
||||
|
||||
else:
|
||||
# The quantized model does not exist, quantize the model and save it.
|
||||
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
|
||||
|
||||
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
# Load sharded state dict.
|
||||
files = list(model_path.glob("*.safetensors"))
|
||||
state_dict = dict()
|
||||
for file in files:
|
||||
sd = load_file(file)
|
||||
state_dict.update(sd)
|
||||
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda and quantize"):
|
||||
model = model.to("cuda")
|
||||
|
||||
with log_time("Save quantized model"):
|
||||
model_int8_path.mkdir(parents=True, exist_ok=True)
|
||||
output_path = model_int8_path / "model.safetensors"
|
||||
save_file(model.state_dict(), output_path)
|
||||
|
||||
print(f"Successfully quantized and saved model to '{output_path}'.")
|
||||
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
91
invokeai/backend/quantization/load_flux_model_bnb_nf4.py
Normal file
91
invokeai/backend/quantization/load_flux_model_bnb_nf4.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from flux.model import Flux
|
||||
from flux.util import configs as flux_configs
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_time(name: str):
|
||||
"""Helper context manager to log the time taken by a block of code."""
|
||||
start = time.time()
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
end = time.time()
|
||||
print(f"'{name}' took {end - start:.4f} secs")
|
||||
|
||||
|
||||
def main():
|
||||
model_path = Path(
|
||||
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
|
||||
)
|
||||
|
||||
# inference_dtype = torch.bfloat16
|
||||
with log_time("Intialize FLUX transformer on meta device"):
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
params = flux_configs["flux-schnell"].params
|
||||
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
modules_to_not_convert: set[str] = set()
|
||||
|
||||
model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
|
||||
if model_nf4_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
|
||||
|
||||
# Replace the linear layers with NF4 quantized linear layers (still on the meta device).
|
||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(
|
||||
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
state_dict = load_file(model_nf4_path)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda"):
|
||||
model = model.to("cuda")
|
||||
|
||||
print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.")
|
||||
|
||||
else:
|
||||
# The quantized model does not exist, quantize the model and save it.
|
||||
print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...")
|
||||
|
||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(
|
||||
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
state_dict = load_file(model_path)
|
||||
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda and quantize"):
|
||||
model = model.to("cuda")
|
||||
|
||||
with log_time("Save quantized model"):
|
||||
model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
save_file(model.state_dict(), model_nf4_path)
|
||||
|
||||
print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
|
||||
|
||||
assert isinstance(model, Flux)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
53
invokeai/backend/requantize.py
Normal file
53
invokeai/backend/requantize.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from optimum.quanto.quantize import _quantize_submodule
|
||||
|
||||
# def custom_freeze(model: torch.nn.Module):
|
||||
# for name, m in model.named_modules():
|
||||
# if isinstance(m, QModuleMixin):
|
||||
# m.weight =
|
||||
# m.freeze()
|
||||
|
||||
|
||||
def requantize(
|
||||
model: torch.nn.Module,
|
||||
state_dict: Dict[str, Any],
|
||||
quantization_map: Dict[str, Dict[str, str]],
|
||||
device: torch.device = None,
|
||||
):
|
||||
if device is None:
|
||||
device = next(model.parameters()).device
|
||||
if device.type == "meta":
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Quantize the model with parameters from the quantization map
|
||||
for name, m in model.named_modules():
|
||||
qconfig = quantization_map.get(name, None)
|
||||
if qconfig is not None:
|
||||
weights = qconfig["weights"]
|
||||
if weights == "none":
|
||||
weights = None
|
||||
activations = qconfig["activations"]
|
||||
if activations == "none":
|
||||
activations = None
|
||||
_quantize_submodule(model, name, m, weights=weights, activations=activations)
|
||||
|
||||
# Move model parameters and buffers to CPU before materializing quantized weights
|
||||
for name, m in model.named_modules():
|
||||
|
||||
def move_tensor(t, device):
|
||||
if t.device.type == "meta":
|
||||
return torch.empty_like(t, device=device)
|
||||
return t.to(device)
|
||||
|
||||
for name, param in m.named_parameters(recurse=False):
|
||||
setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
|
||||
for name, param in m.named_buffers(recurse=False):
|
||||
setattr(m, name, move_tensor(param, "cpu"))
|
||||
# Freeze model and move to target device
|
||||
# freeze(model)
|
||||
# model.to(device)
|
||||
|
||||
# Load the quantized model weights
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
@@ -25,11 +25,6 @@ class BasicConditioningInfo:
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
"""SDXL text conditioning information produced by Compel."""
|
||||
@@ -43,6 +38,17 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FLUXConditioningInfo:
|
||||
clip_embeds: torch.Tensor
|
||||
t5_embeds: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterConditioningInfo:
|
||||
cond_image_prompt_embeds: torch.Tensor
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import diffusers
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import FromOriginalControlNetMixin
|
||||
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
||||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||
from diffusers.models.embeddings import (
|
||||
@@ -32,7 +32,7 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
logger = InvokeAILogger.get_logger(__name__)
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
|
||||
@@ -33,31 +33,37 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
# Core generation dependencies, pinned for reproducible builds.
|
||||
"accelerate==0.30.1",
|
||||
"accelerate==0.33.0",
|
||||
"bitsandbytes==0.43.3",
|
||||
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel==2.0.2",
|
||||
"controlnet-aux==0.0.7",
|
||||
"diffusers[torch]==0.27.2",
|
||||
# TODO(ryand): Bump this once the next diffusers release is ready.
|
||||
"diffusers[torch] @ git+https://github.com/huggingface/diffusers.git@4c6152c2fb0ade468aadb417102605a07a8635d3",
|
||||
"flux @ git+https://github.com/black-forest-labs/flux.git@c23ae247225daba30fbd56058d247cc1b1fc20a3",
|
||||
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
||||
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
||||
"onnx==1.15.0",
|
||||
"onnxruntime==1.16.3",
|
||||
"opencv-python==4.9.0.80",
|
||||
"optimum-quanto==0.2.4",
|
||||
"pytorch-lightning==2.1.3",
|
||||
"safetensors==0.4.3",
|
||||
# sentencepiece is required to load T5TokenizerFast (used by FLUX).
|
||||
"sentencepiece==0.2.0",
|
||||
"spandrel==0.3.4",
|
||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||
"torch==2.2.2",
|
||||
"torch==2.4.0",
|
||||
"torchmetrics==0.11.4",
|
||||
"torchsde==0.2.6",
|
||||
"torchvision==0.17.2",
|
||||
"torchvision==0.19.0",
|
||||
"transformers==4.41.1",
|
||||
|
||||
# Core application dependencies, pinned for reproducible builds.
|
||||
"fastapi-events==0.11.1",
|
||||
"fastapi==0.111.0",
|
||||
"huggingface-hub==0.23.1",
|
||||
"huggingface-hub==0.24.5",
|
||||
"pydantic-settings==2.2.1",
|
||||
"pydantic==2.7.2",
|
||||
"python-socketio==5.11.1",
|
||||
|
||||
@@ -326,3 +326,80 @@ def test_select_multiple_weights(
|
||||
) -> None:
|
||||
filtered_files = filter_files(sd15_test_files, variant)
|
||||
assert set(filtered_files) == {Path(f) for f in expected_files}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flux_schnell_test_files() -> list[Path]:
|
||||
return [
|
||||
Path(f)
|
||||
for f in [
|
||||
"FLUX.1-schnell/.gitattributes",
|
||||
"FLUX.1-schnell/README.md",
|
||||
"FLUX.1-schnell/ae.safetensors",
|
||||
"FLUX.1-schnell/flux1-schnell.safetensors",
|
||||
"FLUX.1-schnell/model_index.json",
|
||||
"FLUX.1-schnell/scheduler/scheduler_config.json",
|
||||
"FLUX.1-schnell/schnell_grid.jpeg",
|
||||
"FLUX.1-schnell/text_encoder/config.json",
|
||||
"FLUX.1-schnell/text_encoder/model.safetensors",
|
||||
"FLUX.1-schnell/text_encoder_2/config.json",
|
||||
"FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors",
|
||||
"FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors",
|
||||
"FLUX.1-schnell/text_encoder_2/model.safetensors.index.json",
|
||||
"FLUX.1-schnell/tokenizer/merges.txt",
|
||||
"FLUX.1-schnell/tokenizer/special_tokens_map.json",
|
||||
"FLUX.1-schnell/tokenizer/tokenizer_config.json",
|
||||
"FLUX.1-schnell/tokenizer/vocab.json",
|
||||
"FLUX.1-schnell/tokenizer_2/special_tokens_map.json",
|
||||
"FLUX.1-schnell/tokenizer_2/spiece.model",
|
||||
"FLUX.1-schnell/tokenizer_2/tokenizer.json",
|
||||
"FLUX.1-schnell/tokenizer_2/tokenizer_config.json",
|
||||
"FLUX.1-schnell/transformer/config.json",
|
||||
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
"FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json",
|
||||
"FLUX.1-schnell/vae/config.json",
|
||||
"FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["variant", "expected_files"],
|
||||
[
|
||||
(
|
||||
ModelRepoVariant.Default,
|
||||
[
|
||||
"FLUX.1-schnell/model_index.json",
|
||||
"FLUX.1-schnell/scheduler/scheduler_config.json",
|
||||
"FLUX.1-schnell/text_encoder/config.json",
|
||||
"FLUX.1-schnell/text_encoder/model.safetensors",
|
||||
"FLUX.1-schnell/text_encoder_2/config.json",
|
||||
"FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors",
|
||||
"FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors",
|
||||
"FLUX.1-schnell/text_encoder_2/model.safetensors.index.json",
|
||||
"FLUX.1-schnell/tokenizer/merges.txt",
|
||||
"FLUX.1-schnell/tokenizer/special_tokens_map.json",
|
||||
"FLUX.1-schnell/tokenizer/tokenizer_config.json",
|
||||
"FLUX.1-schnell/tokenizer/vocab.json",
|
||||
"FLUX.1-schnell/tokenizer_2/special_tokens_map.json",
|
||||
"FLUX.1-schnell/tokenizer_2/spiece.model",
|
||||
"FLUX.1-schnell/tokenizer_2/tokenizer.json",
|
||||
"FLUX.1-schnell/tokenizer_2/tokenizer_config.json",
|
||||
"FLUX.1-schnell/transformer/config.json",
|
||||
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
"FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json",
|
||||
"FLUX.1-schnell/vae/config.json",
|
||||
"FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_select_flux_schnell_files(
|
||||
flux_schnell_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str]
|
||||
) -> None:
|
||||
filtered_files = filter_files(flux_schnell_test_files, variant)
|
||||
assert set(filtered_files) == {Path(f) for f in expected_files}
|
||||
|
||||
Reference in New Issue
Block a user