Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-01-23 02:08:23 -05:00

Compare commits: v4.2.9.dev...ryan/flux (14 commits)
| SHA1 |
|---|
| a8a2fc106d |
| d23ad1818d |
| 4181ab654b |
| 1c97360f9f |
| 74d6fceeb6 |
| 766ddc18dc |
| e6ff7488a1 |
| 89a652cfcd |
| b227b9059d |
| 3599a4a3e4 |
| 5dd619e137 |
| 7d447cbb88 |
| 3bbba7e4b1 |
| b1845019fe |
invokeai/app/invocations/flux_text_to_image.py (new file, 278 lines)

```python
from pathlib import Path
from typing import Literal

import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import qfloat8
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers.models.auto import AutoModelForTextEncoding

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
from invokeai.backend.util.devices import TorchDevice

TFluxModelKeys = Literal["flux-schnell"]
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}


class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
    base_class = FluxTransformer2DModel


class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
    auto_class = AutoModelForTextEncoding


@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image"],
    category="image",
    version="1.0.0",
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
    use_8bit: bool = InputField(
        default=False, description="Whether to quantize the transformer model to 8-bit precision."
    )
    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(default=4, description="Number of diffusion steps.")
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])

        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
        latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
        image = self._run_vae_decoding(context, model_path, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
        # Determine the T5 max sequence length based on the model.
        if self.model == "flux-schnell":
            max_seq_len = 256
        # elif self.model == "flux-dev":
        #     max_seq_len = 512
        else:
            raise ValueError(f"Unknown model: {self.model}")

        # Load the CLIP tokenizer.
        clip_tokenizer_path = flux_model_dir / "tokenizer"
        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
        assert isinstance(clip_tokenizer, CLIPTokenizer)

        # Load the T5 tokenizer.
        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
        assert isinstance(t5_tokenizer, T5TokenizerFast)

        clip_text_encoder_path = flux_model_dir / "text_encoder"
        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
        with (
            context.models.load_local_model(
                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
            ) as clip_text_encoder,
            context.models.load_local_model(
                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
            ) as t5_text_encoder,
        ):
            assert isinstance(clip_text_encoder, CLIPTextModel)
            assert isinstance(t5_text_encoder, T5EncoderModel)
            pipeline = FluxPipeline(
                scheduler=None,
                vae=None,
                text_encoder=clip_text_encoder,
                tokenizer=clip_tokenizer,
                text_encoder_2=t5_text_encoder,
                tokenizer_2=t5_tokenizer,
                transformer=None,
            )

            # prompt_embeds: T5 embeddings
            # pooled_prompt_embeds: CLIP embeddings
            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
                prompt=self.positive_prompt,
                prompt_2=self.positive_prompt,
                device=TorchDevice.choose_torch_device(),
                max_sequence_length=max_seq_len,
            )

        assert isinstance(prompt_embeds, torch.Tensor)
        assert isinstance(pooled_prompt_embeds, torch.Tensor)
        return prompt_embeds, pooled_prompt_embeds

    def _run_diffusion(
        self,
        context: InvocationContext,
        flux_model_dir: Path,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ):
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)

        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
        # if the cache is not empty.
        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

        transformer_path = flux_model_dir / "transformer"
        with context.models.load_local_model(
            model_path=transformer_path, loader=self._load_flux_transformer
        ) as transformer:
            assert isinstance(transformer, FluxTransformer2DModel)

            flux_pipeline_with_transformer = FluxPipeline(
                scheduler=scheduler,
                vae=None,
                text_encoder=None,
                tokenizer=None,
                text_encoder_2=None,
                tokenizer_2=None,
                transformer=transformer,
            )

            t5_embeddings = t5_embeddings.to(dtype=transformer.dtype)
            clip_embeddings = clip_embeddings.to(dtype=transformer.dtype)

            latents = flux_pipeline_with_transformer(
                height=self.height,
                width=self.width,
                num_inference_steps=self.num_steps,
                guidance_scale=self.guidance,
                generator=torch.Generator().manual_seed(self.seed),
                prompt_embeds=t5_embeddings,
                pooled_prompt_embeds=clip_embeddings,
                output_type="latent",
                return_dict=False,
            )[0]

        assert isinstance(latents, torch.Tensor)
        return latents

    def _run_vae_decoding(
        self,
        context: InvocationContext,
        flux_model_dir: Path,
        latents: torch.Tensor,
    ) -> Image.Image:
        vae_path = flux_model_dir / "vae"
        with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
            assert isinstance(vae, AutoencoderKL)

            flux_pipeline_with_vae = FluxPipeline(
                scheduler=None,
                vae=vae,
                text_encoder=None,
                tokenizer=None,
                text_encoder_2=None,
                tokenizer_2=None,
                transformer=None,
            )

            latents = flux_pipeline_with_vae._unpack_latents(
                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
            )
            latents = (
                latents / flux_pipeline_with_vae.vae.config.scaling_factor
            ) + flux_pipeline_with_vae.vae.config.shift_factor
            latents = latents.to(dtype=vae.dtype)
            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]

        assert isinstance(image, Image.Image)
        return image

    @staticmethod
    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, CLIPTextModel)
        return model

    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
        if self.use_8bit:
            model_8bit_path = path / "quantized"
            if model_8bit_path.exists():
                # The quantized model exists, load it.
                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
                # something that we should be able to make much faster.
                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)

                # Access the underlying wrapped model.
                # We access the wrapped model, even though it is private, because it simplifies the type checking by
                # always returning a T5EncoderModel from this function.
                model = q_model._wrapped
            else:
                # The quantized model does not exist yet, quantize and save it.
                # TODO(ryand): dtype?
                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
                assert isinstance(model, T5EncoderModel)

                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)

                model_8bit_path.mkdir(parents=True, exist_ok=True)
                q_model.save_pretrained(model_8bit_path)

                # (See earlier comment about accessing the wrapped model.)
                model = q_model._wrapped
        else:
            model = T5EncoderModel.from_pretrained(path, local_files_only=True)

        assert isinstance(model, T5EncoderModel)
        return model

    def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
        if self.use_8bit:
            model_8bit_path = path / "quantized"
            if model_8bit_path.exists():
                # The quantized model exists, load it.
                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
                # something that we should be able to make much faster.
                q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)

                # Access the underlying wrapped model.
                # We access the wrapped model, even though it is private, because it simplifies the type checking by
                # always returning a FluxTransformer2DModel from this function.
                model = q_model._wrapped
            else:
                # The quantized model does not exist yet, quantize and save it.
                # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
                # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
                # here.
                model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
                assert isinstance(model, FluxTransformer2DModel)

                q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)

                model_8bit_path.mkdir(parents=True, exist_ok=True)
                q_model.save_pretrained(model_8bit_path)

                # (See earlier comment about accessing the wrapped model.)
                model = q_model._wrapped
        else:
            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)

        assert isinstance(model, FluxTransformer2DModel)
        return model

    @staticmethod
    def _load_flux_vae(path: Path) -> AutoencoderKL:
        model = AutoencoderKL.from_pretrained(path, local_files_only=True)
        assert isinstance(model, AutoencoderKL)
        return model
```
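The invocation above never keeps more than one large FLUX sub-model in memory at a time: prompt encoding runs with only the tokenizers and text encoders loaded, denoising with only the transformer, and decoding with only the VAE. The sketch below illustrates the same staging against diffusers' `FluxPipeline` directly; it is not part of the diff, and the repo id, prompt, and dtype are placeholder assumptions.

```python
# Hedged sketch of the staged loading used by FluxTextToImageInvocation, written against
# diffusers' FluxPipeline directly. Repo id, prompt, and dtype are placeholders.
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

repo = "black-forest-labs/FLUX.1-schnell"

# Stage 1: encode the prompt with only the tokenizers and text encoders instantiated.
pipe = FluxPipeline.from_pretrained(repo, transformer=None, vae=None, torch_dtype=torch.bfloat16)
prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
    prompt="a photo of a cat", prompt_2="a photo of a cat", max_sequence_length=256
)
del pipe  # free the text encoders before the (much larger) transformer is loaded

# Stages 2 and 3 follow the same pattern: rebuild the pipeline with only the transformer
# (denoising to latents), then with only the VAE (decoding), passing the cached
# prompt_embeds / pooled_prompt_embeds instead of raw text.
```

Inside InvokeAI, the node achieves the same effect with `context.models.load_local_model(...)`, which additionally lets the model manager cache and evict each component.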
invokeai/backend/load_flux_model.py (new file, 129 lines)

```python
import json
import os
import time
from pathlib import Path
from typing import Union

import torch
from diffusers.models.model_loading_utils import load_state_dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import (
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_WEIGHTS_NAME,
    _get_checkpoint_shard_files,
    is_accelerate_available,
)
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict

from invokeai.backend.requantize import requantize


class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

    @classmethod
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
        if cls.base_class is None:
            raise ValueError("The `base_class` attribute needs to be configured.")

        if not is_accelerate_available():
            raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
        from accelerate import init_empty_weights

        if os.path.isdir(model_name_or_path):
            # Look for a quantization map
            qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
            if not os.path.exists(qmap_path):
                raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")

            # Look for original model config file.
            model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
            if not os.path.exists(model_config_path):
                raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")

            with open(qmap_path, "r", encoding="utf-8") as f:
                qmap = json.load(f)

            with open(model_config_path, "r", encoding="utf-8") as f:
                original_model_cls_name = json.load(f)["_class_name"]
            configured_cls_name = cls.base_class.__name__
            if configured_cls_name != original_model_cls_name:
                raise ValueError(
                    f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
                )

            # Create an empty model
            config = cls.base_class.load_config(model_name_or_path)
            with init_empty_weights():
                model = cls.base_class.from_config(config)

            # Look for the index of a sharded checkpoint
            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
            if os.path.exists(checkpoint_file):
                # Convert the checkpoint path to a list of shards
                _, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
                # Create a mapping for the sharded safetensor files
                state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
            else:
                # Look for a single checkpoint file
                checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
                if not os.path.exists(checkpoint_file):
                    raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
                # Get state_dict from model checkpoint
                state_dict = load_state_dict(checkpoint_file)

            # Requantize and load quantized weights from state_dict
            requantize(model, state_dict=state_dict, quantization_map=qmap)
            model.eval()
            return cls(model)
        else:
            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")


def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
    # model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
    model_8bit_path = path / "quantized"
    if model_8bit_path.exists():
        # The quantized model exists, load it.
        # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
        # something that we should be able to make much faster.
        q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)

        # Access the underlying wrapped model.
        # We access the wrapped model, even though it is private, because it simplifies the type checking by
        # always returning a FluxTransformer2DModel from this function.
        model = q_model._wrapped
    else:
        # The quantized model does not exist yet, quantize and save it.
        # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
        # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
        # here.
        model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
        assert isinstance(model, FluxTransformer2DModel)

        q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)

        model_8bit_path.mkdir(parents=True, exist_ok=True)
        q_model.save_pretrained(model_8bit_path)

        # (See earlier comment about accessing the wrapped model.)
        model = q_model._wrapped

    assert isinstance(model, FluxTransformer2DModel)
    return model


def main():
    start = time.time()
    model = load_flux_transformer(
        Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/")
    )
    print(f"Time to load: {time.time() - start}s")
    print("hi")


if __name__ == "__main__":
    main()
```
Changes to the Hugging Face file-selection helpers `filter_files()` and `_filter_by_variant()`:

```diff
@@ -54,6 +54,7 @@ def filter_files(
                 "lora_weights.safetensors",
                 "weights.pb",
                 "onnx_data",
+                "spiece.model",  # Added for `black-forest-labs/FLUX.1-schnell`.
             )
         ):
             paths.append(file)
@@ -62,7 +63,7 @@ def filter_files(
         # downloading random checkpoints that might also be in the repo. However there is no guarantee
         # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
         # will adhere to this naming convention, so this is an area to be careful of.
-        elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
+        elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
            paths.append(file)

    # limit search to subfolder if requested
@@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
            if variant == ModelRepoVariant.Flax:
                result.add(path)

-        elif path.suffix in [".json", ".txt"]:
+        # Note: '.model' was added to support:
+        # https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
+        elif path.suffix in [".json", ".txt", ".model"]:
            result.add(path)

        elif variant in [
@@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
                continue

    for candidate_list in subfolder_weights.values():
+        # Check if at least one of the files has the explicit fp16 variant.
+        at_least_one_fp16 = False
+        for candidate in candidate_list:
+            if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
+                at_least_one_fp16 = True
+                break
+
+        if not at_least_one_fp16:
+            # If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
+            # candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
+            # we'll simply keep all the candidates. An example of a model that hits this case is
+            # `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
+            for candidate in candidate_list:
+                result.add(candidate.path)
+
+        # The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
+        # candidate.
        highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
        if highest_score_candidate:
            result.add(highest_score_candidate.path)
```
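The loosened regex in the second hunk is what lets the sharded T5 checkpoints in FLUX.1-schnell (for example `model-00001-of-00002.safetensors`) be selected; the old pattern only tolerated a single `.variant` suffix between `model` and the extension. A quick check, as an illustrative sketch (not part of the diff):

```python
import re

# Compare the old and new patterns from the hunk above against a FLUX T5 shard name.
old = r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
new = r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"

name = "model-00001-of-00002.safetensors"  # text_encoder_2 shard in FLUX.1-schnell
print(bool(re.search(old, name)))  # False: only "model.safetensors" or "model.<variant>.safetensors" matched
print(bool(re.search(new, name)))  # True: anything between "model" and the extension is now accepted
```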
invokeai/backend/quantization/fast_quantized_diffusion_model.py (new file, 77 lines)

```python
import json
import os
from typing import Union

from diffusers.models.model_loading_utils import load_state_dict
from diffusers.utils import (
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_WEIGHTS_NAME,
    _get_checkpoint_shard_files,
    is_accelerate_available,
)
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict

from invokeai.backend.requantize import requantize


class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
    @classmethod
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
        """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
        if cls.base_class is None:
            raise ValueError("The `base_class` attribute needs to be configured.")

        if not is_accelerate_available():
            raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
        from accelerate import init_empty_weights

        if os.path.isdir(model_name_or_path):
            # Look for a quantization map
            qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
            if not os.path.exists(qmap_path):
                raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")

            # Look for original model config file.
            model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
            if not os.path.exists(model_config_path):
                raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")

            with open(qmap_path, "r", encoding="utf-8") as f:
                qmap = json.load(f)

            with open(model_config_path, "r", encoding="utf-8") as f:
                original_model_cls_name = json.load(f)["_class_name"]
            configured_cls_name = cls.base_class.__name__
            if configured_cls_name != original_model_cls_name:
                raise ValueError(
                    f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
                )

            # Create an empty model
            config = cls.base_class.load_config(model_name_or_path)
            with init_empty_weights():
                model = cls.base_class.from_config(config)

            # Look for the index of a sharded checkpoint
            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
            if os.path.exists(checkpoint_file):
                # Convert the checkpoint path to a list of shards
                _, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
                # Create a mapping for the sharded safetensor files
                state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
            else:
                # Look for a single checkpoint file
                checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
                if not os.path.exists(checkpoint_file):
                    raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
                # Get state_dict from model checkpoint
                state_dict = load_state_dict(checkpoint_file)

            # Requantize and load quantized weights from state_dict
            requantize(model, state_dict=state_dict, quantization_map=qmap)
            model.eval()
            return cls(model)
        else:
            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
```
invokeai/backend/quantization/fast_quantized_transformers_model.py (new file, 61 lines)

```python
import json
import os
from typing import Union

from optimum.quanto.models import QuantizedTransformersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from transformers import AutoConfig
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available

from invokeai.backend.requantize import requantize


class FastQuantizedTransformersModel(QuantizedTransformersModel):
    @classmethod
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
        """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
        if cls.auto_class is None:
            raise ValueError(
                f"Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
            )
        if not is_accelerate_available():
            raise ValueError("Reloading a quantized transformers model requires the accelerate library.")
        from accelerate import init_empty_weights

        if os.path.isdir(model_name_or_path):
            # Look for a quantization map
            qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
            if not os.path.exists(qmap_path):
                raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
            with open(qmap_path, "r", encoding="utf-8") as f:
                qmap = json.load(f)
            # Create an empty model
            config = AutoConfig.from_pretrained(model_name_or_path)
            with init_empty_weights():
                model = cls.auto_class.from_config(config)
            # Look for the index of a sharded checkpoint
            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
            if os.path.exists(checkpoint_file):
                # Convert the checkpoint path to a list of shards
                checkpoint_file, sharded_metadata = get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
                # Create a mapping for the sharded safetensor files
                state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
            else:
                # Look for a single checkpoint file
                checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)
                if not os.path.exists(checkpoint_file):
                    raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
                # Get state_dict from model checkpoint
                state_dict = load_state_dict(checkpoint_file)
            # Requantize and load quantized weights from state_dict
            requantize(model, state_dict=state_dict, quantization_map=qmap)
            if getattr(model.config, "tie_word_embeddings", True):
                # Tie output weight embeddings to input weight embeddings
                # Note that if they were quantized they would NOT be tied
                model.tie_weights()
            # Set model in evaluation mode as it is done in transformers
            model.eval()
            return cls(model)
        else:
            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
```
invokeai/backend/requantize.py (new file, 53 lines)

```python
from typing import Any, Dict

import torch
from optimum.quanto.quantize import _quantize_submodule

# def custom_freeze(model: torch.nn.Module):
#     for name, m in model.named_modules():
#         if isinstance(m, QModuleMixin):
#             m.weight =
#             m.freeze()


def requantize(
    model: torch.nn.Module,
    state_dict: Dict[str, Any],
    quantization_map: Dict[str, Dict[str, str]],
    device: torch.device = None,
):
    if device is None:
        device = next(model.parameters()).device
        if device.type == "meta":
            device = torch.device("cpu")

    # Quantize the model with parameters from the quantization map
    for name, m in model.named_modules():
        qconfig = quantization_map.get(name, None)
        if qconfig is not None:
            weights = qconfig["weights"]
            if weights == "none":
                weights = None
            activations = qconfig["activations"]
            if activations == "none":
                activations = None
            _quantize_submodule(model, name, m, weights=weights, activations=activations)

    # Move model parameters and buffers to CPU before materializing quantized weights
    for name, m in model.named_modules():

        def move_tensor(t, device):
            if t.device.type == "meta":
                return torch.empty_like(t, device=device)
            return t.to(device)

        for name, param in m.named_parameters(recurse=False):
            setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
        for name, param in m.named_buffers(recurse=False):
            setattr(m, name, move_tensor(param, "cpu"))

    # Freeze model and move to target device
    # freeze(model)
    # model.to(device)

    # Load the quantized model weights
    model.load_state_dict(state_dict, strict=False)
```
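For orientation, the save/reload cycle that `requantize()` supports looks roughly like the sketch below. It is illustrative only (a toy module rather than a FLUX component); the wrapper classes above perform the reload half of this cycle when they call `requantize()` inside `from_pretrained()`.

```python
# Hedged sketch: quantize a small module with optimum.quanto, capture its state dict and
# quantization map, then rebuild an equivalent quantized module via requantize().
import torch
from optimum.quanto import freeze, qfloat8, quantization_map, quantize

from invokeai.backend.requantize import requantize

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
quantize(model, weights=qfloat8)
freeze(model)  # materialize the quantized weights
state_dict = model.state_dict()
qmap = quantization_map(model)  # records which submodules were quantized, and how

# Later (e.g. after loading state_dict and qmap from disk): rebuild on a fresh module.
fresh = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
requantize(fresh, state_dict=state_dict, quantization_map=qmap)
```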
`ControlNetModel` import update for the newer diffusers pin (`FromOriginalControlNetMixin` was replaced by `FromOriginalModelMixin`):

```diff
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import diffusers
 import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.loaders import FromOriginalControlNetMixin
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
 from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
 from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
 from diffusers.models.embeddings import (
@@ -32,7 +32,7 @@ from invokeai.backend.util.logging import InvokeAILogger
 logger = InvokeAILogger.get_logger(__name__)
 
 
-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     """
     A ControlNet model.
 
```
|||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# Core generation dependencies, pinned for reproducible builds.
|
# Core generation dependencies, pinned for reproducible builds.
|
||||||
"accelerate==0.30.1",
|
"accelerate==0.33.0",
|
||||||
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||||
"compel==2.0.2",
|
"compel==2.0.2",
|
||||||
"controlnet-aux==0.0.7",
|
"controlnet-aux==0.0.7",
|
||||||
"diffusers[torch]==0.27.2",
|
# TODO(ryand): Bump this once the next diffusers release is ready.
|
||||||
|
"diffusers[torch] @ git+https://github.com/huggingface/diffusers.git@4c6152c2fb0ade468aadb417102605a07a8635d3",
|
||||||
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||||
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
||||||
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
||||||
"onnx==1.15.0",
|
"onnx==1.15.0",
|
||||||
"onnxruntime==1.16.3",
|
"onnxruntime==1.16.3",
|
||||||
"opencv-python==4.9.0.80",
|
"opencv-python==4.9.0.80",
|
||||||
|
"optimum-quanto==0.2.4",
|
||||||
"pytorch-lightning==2.1.3",
|
"pytorch-lightning==2.1.3",
|
||||||
"safetensors==0.4.3",
|
"safetensors==0.4.3",
|
||||||
|
# sentencepiece is required to load T5TokenizerFast (used by FLUX).
|
||||||
|
"sentencepiece==0.2.0",
|
||||||
"spandrel==0.3.4",
|
"spandrel==0.3.4",
|
||||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||||
"torch==2.2.2",
|
"torch==2.4.0",
|
||||||
"torchmetrics==0.11.4",
|
"torchmetrics==0.11.4",
|
||||||
"torchsde==0.2.6",
|
"torchsde==0.2.6",
|
||||||
"torchvision==0.17.2",
|
"torchvision==0.19.0",
|
||||||
"transformers==4.41.1",
|
"transformers==4.41.1",
|
||||||
|
|
||||||
# Core application dependencies, pinned for reproducible builds.
|
# Core application dependencies, pinned for reproducible builds.
|
||||||
"fastapi-events==0.11.1",
|
"fastapi-events==0.11.1",
|
||||||
"fastapi==0.111.0",
|
"fastapi==0.111.0",
|
||||||
"huggingface-hub==0.23.1",
|
"huggingface-hub==0.24.5",
|
||||||
"pydantic-settings==2.2.1",
|
"pydantic-settings==2.2.1",
|
||||||
"pydantic==2.7.2",
|
"pydantic==2.7.2",
|
||||||
"python-socketio==5.11.1",
|
"python-socketio==5.11.1",
|
||||||
|
|||||||
New test coverage for FLUX.1-schnell file selection:

```diff
@@ -326,3 +326,80 @@ def test_select_multiple_weights(
 ) -> None:
     filtered_files = filter_files(sd15_test_files, variant)
     assert set(filtered_files) == {Path(f) for f in expected_files}
+
+
+@pytest.fixture
+def flux_schnell_test_files() -> list[Path]:
+    return [
+        Path(f)
+        for f in [
+            "FLUX.1-schnell/.gitattributes",
+            "FLUX.1-schnell/README.md",
+            "FLUX.1-schnell/ae.safetensors",
+            "FLUX.1-schnell/flux1-schnell.safetensors",
+            "FLUX.1-schnell/model_index.json",
+            "FLUX.1-schnell/scheduler/scheduler_config.json",
+            "FLUX.1-schnell/schnell_grid.jpeg",
+            "FLUX.1-schnell/text_encoder/config.json",
+            "FLUX.1-schnell/text_encoder/model.safetensors",
+            "FLUX.1-schnell/text_encoder_2/config.json",
+            "FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors",
+            "FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors",
+            "FLUX.1-schnell/text_encoder_2/model.safetensors.index.json",
+            "FLUX.1-schnell/tokenizer/merges.txt",
+            "FLUX.1-schnell/tokenizer/special_tokens_map.json",
+            "FLUX.1-schnell/tokenizer/tokenizer_config.json",
+            "FLUX.1-schnell/tokenizer/vocab.json",
+            "FLUX.1-schnell/tokenizer_2/special_tokens_map.json",
+            "FLUX.1-schnell/tokenizer_2/spiece.model",
+            "FLUX.1-schnell/tokenizer_2/tokenizer.json",
+            "FLUX.1-schnell/tokenizer_2/tokenizer_config.json",
+            "FLUX.1-schnell/transformer/config.json",
+            "FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
+            "FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors",
+            "FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors",
+            "FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json",
+            "FLUX.1-schnell/vae/config.json",
+            "FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors",
+        ]
+    ]
+
+
+@pytest.mark.parametrize(
+    ["variant", "expected_files"],
+    [
+        (
+            ModelRepoVariant.Default,
+            [
+                "FLUX.1-schnell/model_index.json",
+                "FLUX.1-schnell/scheduler/scheduler_config.json",
+                "FLUX.1-schnell/text_encoder/config.json",
+                "FLUX.1-schnell/text_encoder/model.safetensors",
+                "FLUX.1-schnell/text_encoder_2/config.json",
+                "FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors",
+                "FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors",
+                "FLUX.1-schnell/text_encoder_2/model.safetensors.index.json",
+                "FLUX.1-schnell/tokenizer/merges.txt",
+                "FLUX.1-schnell/tokenizer/special_tokens_map.json",
+                "FLUX.1-schnell/tokenizer/tokenizer_config.json",
+                "FLUX.1-schnell/tokenizer/vocab.json",
+                "FLUX.1-schnell/tokenizer_2/special_tokens_map.json",
+                "FLUX.1-schnell/tokenizer_2/spiece.model",
+                "FLUX.1-schnell/tokenizer_2/tokenizer.json",
+                "FLUX.1-schnell/tokenizer_2/tokenizer_config.json",
+                "FLUX.1-schnell/transformer/config.json",
+                "FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
+                "FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors",
+                "FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors",
+                "FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json",
+                "FLUX.1-schnell/vae/config.json",
+                "FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors",
+            ],
+        ),
+    ],
+)
+def test_select_flux_schnell_files(
+    flux_schnell_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str]
+) -> None:
+    filtered_files = filter_files(flux_schnell_test_files, variant)
+    assert set(filtered_files) == {Path(f) for f in expected_files}
```