Initial GGUF support for flux models
commit 2bfb0ddff5 · parent 950c9f5d0c · committed by Kent Keirsey
@@ -114,6 +114,7 @@ class ModelFormat(str, Enum):
    T5Encoder = "t5_encoder"
    BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
    BnbQuantizednf4b = "bnb_quantized_nf4b"
    GGUFQuantized = "gguf_quantized"


class SchedulerPredictionType(str, Enum):
@@ -197,7 +198,7 @@ class ModelConfigBase(BaseModel):
class CheckpointConfigBase(ModelConfigBase):
    """Model config for checkpoint-style models."""

    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
        description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
    )
    config_path: str = Field(description="path to the checkpoint model config file")
@@ -363,6 +364,21 @@ class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
        return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")


class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase):
    """Model config for main checkpoint models."""

    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.format = ModelFormat.GGUFQuantized

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}")


class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
    """Model config for main diffusers models."""

@@ -466,6 +482,7 @@ AnyModelConfig = Annotated[
        Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
        Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
        Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
        Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
        Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
        Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
        Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
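The AnyModelConfig union above dispatches on each config class's get_tag() value. A minimal, self-contained sketch of that pydantic tagged-union pattern follows; the class names and discriminator function are simplified, hypothetical stand-ins for the real configs, which also fold the model type into the tag.

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class CheckpointCfg(BaseModel):
    format: Literal["checkpoint"] = "checkpoint"


class GGUFCfg(BaseModel):
    format: Literal["gguf_quantized"] = "gguf_quantized"


def get_format(v) -> str:
    # Route on the "format" field, whether we got a dict or a model instance.
    return v.get("format") if isinstance(v, dict) else v.format


AnyCfg = Annotated[
    Union[
        Annotated[CheckpointCfg, Tag("checkpoint")],
        Annotated[GGUFCfg, Tag("gguf_quantized")],
    ],
    Discriminator(get_format),
]

cfg = TypeAdapter(AnyCfg).validate_python({"format": "gguf_quantized"})
assert isinstance(cfg, GGUFCfg)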
@@ -26,6 +26,7 @@ from invokeai.backend.model_manager.config import (
    CLIPEmbedDiffusersConfig,
    MainBnbQuantized4bCheckpointConfig,
    MainCheckpointConfig,
    MainGGUFCheckpointConfig,
    T5EncoderBnbQuantizedLlmInt8bConfig,
    T5EncoderConfig,
    VAECheckpointConfig,
@@ -35,6 +36,8 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
from invokeai.backend.model_manager.util.model_util import (
    convert_bundle_to_flux_transformer_checkpoint,
)
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.quantization.gguf.torch_patcher import GGUFPatcher
from invokeai.backend.util.silence_warnings import SilenceWarnings

try:
@@ -204,6 +207,50 @@ class FluxCheckpointModel(ModelLoader):
        return model


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
class FluxGGUFCheckpointModel(ModelLoader):
    """Class to load GGUF main models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, CheckpointConfigBase):
            raise ValueError("Only CheckpointConfigBase models are currently supported here.")

        match submodel_type:
            case SubModelType.Transformer:
                return self._load_from_singlefile(config)

        raise ValueError(
            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    def _load_from_singlefile(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        assert isinstance(config, MainGGUFCheckpointConfig)
        model_path = Path(config.path)

        with SilenceWarnings(), GGUFPatcher().wrap():
            # Load the state dict and patcher
            sd = gguf_sd_loader(model_path)
            # Initialize the model
            model = Flux(params[config.config_path])

            # Calculate new state dictionary size and make room in the cache
            new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
            self._ram_cache.make_room(new_sd_size)

            # Load the state dict into the model
            model.load_state_dict(sd, assign=True)

        # Return the model after patching
        return model


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
    """Class to load main models."""
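One detail worth calling out in _load_from_singlefile is load_state_dict(sd, assign=True): instead of copying values into the module's existing parameters, assign re-points the parameters at the loaded tensors, which is what lets the quantized GGUFTensor objects land on the module unchanged. A toy illustration of the assign semantics with stock torch (the Linear module here is hypothetical; assign requires a reasonably recent torch):

import torch
from torch import nn

lin = nn.Linear(2, 2)
new_w = torch.zeros(2, 2)

# assign=True wraps the provided tensors directly instead of copying them
# into the pre-existing parameter storage.
lin.load_state_dict({"weight": new_w, "bias": torch.zeros(2)}, assign=True)
assert lin.weight.data_ptr() == new_w.data_ptr()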
@@ -30,6 +30,8 @@ from invokeai.backend.model_manager.config import (
    SchedulerPredictionType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings

@@ -187,6 +189,7 @@ class ModelProbe(object):
        if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
            ModelFormat.Checkpoint,
            ModelFormat.BnbQuantizednf4b,
            ModelFormat.GGUFQuantized,
        ]:
            ckpt_config_path = cls._get_checkpoint_config_path(
                model_path,
@@ -220,7 +223,7 @@ class ModelProbe(object):

    @classmethod
    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth", ".gguf"):
            raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")

        if model_path.name == "learned_embeds.bin":
@@ -408,6 +411,8 @@ class ModelProbe(object):
                model = torch.load(model_path, map_location="cpu")
                assert isinstance(model, dict)
                return model
            elif model_path.suffix.endswith(".gguf"):
                return gguf_sd_loader(model_path)
            else:
                return safetensors.torch.load_file(model_path)

@@ -477,6 +482,8 @@ class CheckpointProbeBase(ProbeBase):
            or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
        ):
            return ModelFormat.BnbQuantizednf4b
        elif any(isinstance(v, GGUFTensor) for v in state_dict.values()):
            return ModelFormat.GGUFQuantized
        return ModelFormat("checkpoint")

    def get_variant_type(self) -> ModelVariantType:
@@ -8,6 +8,8 @@ import safetensors
import torch
from picklescan.scanner import scan_file_path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader


def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
    checkpoint = {}
@@ -54,7 +56,10 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
            scan_result = scan_file_path(path)
            if scan_result.infected_files != 0:
                raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
        checkpoint = torch.load(path, map_location=torch.device("meta"))
        if str(path).endswith(".gguf"):
            checkpoint = gguf_sd_loader(Path(path))
        else:
            checkpoint = torch.load(path, map_location=torch.device("meta"))
    return checkpoint
invokeai/backend/quantization/gguf/layers.py (new file, 151 lines)
@@ -0,0 +1,151 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from typing import Optional, Union

import gguf
from torch import Tensor, device, dtype, float32, nn, zeros_like

from invokeai.backend.quantization.gguf.utils import dequantize_tensor, is_quantized

PATCH_TYPES = Union[list[Tensor], tuple[Tensor]]


class GGUFTensor(Tensor):
    """
    Main tensor-like class for storing quantized weights
    """

    def __init__(self, *args, tensor_type, tensor_shape, patches=None, **kwargs):
        super().__init__()
        self.tensor_type = tensor_type
        self.tensor_shape = tensor_shape
        self.patches = patches or []

    def __new__(cls, *args, tensor_type, tensor_shape, patches=None, **kwargs):
        return super().__new__(cls, *args, **kwargs)

    def to(self, *args, **kwargs):
        new = super().to(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new

    def clone(self, *args, **kwargs):
        return self

    def detach(self, *args, **kwargs):
        return self

    def copy_(self, *args, **kwargs):
        # fixes .weight.copy_ in comfy/clip_model/CLIPTextModel
        try:
            return super().copy_(*args, **kwargs)
        except Exception as e:
            print(f"ignoring 'copy_' on tensor: {e}")

    def __deepcopy__(self, *args, **kwargs):
        # Intel Arc fix, ref#50
        new = super().__deepcopy__(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new

    @property
    def shape(self):
        if not hasattr(self, "tensor_shape"):
            self.tensor_shape = self.size()
        return self.tensor_shape


class GGUFLayer(nn.Module):
    """
    This (should) be responsible for de-quantizing on the fly
    """

    dequant_dtype = None
    patch_dtype = None
    torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

    def is_ggml_quantized(self, *, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None):
        if weight is None or bias is None:
            return False
        return is_quantized(weight) or is_quantized(bias)

    def _load_from_state_dict(self, state_dict: dict[str, Tensor], prefix: str, *args, **kwargs):
        weight, bias = state_dict.get(f"{prefix}weight", None), state_dict.get(f"{prefix}bias", None)
        if self.is_ggml_quantized(weight=weight, bias=bias):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def ggml_load_from_state_dict(
        self,
        state_dict: dict[str, Tensor],
        prefix: str,
        local_metadata,
        strict,
        missing_keys: list[str],
        unexpected_keys,
        error_msgs,
    ):
        for k, v in state_dict.items():
            if k.endswith("weight"):
                self.weight = nn.Parameter(v, requires_grad=False)
            elif k.endswith("bias") and v is not None:
                self.bias = nn.Parameter(v, requires_grad=False)
            else:
                missing_keys.append(k)

    def _save_to_state_dict(self, *args, **kwargs):
        if self.is_ggml_quantized():
            return self.ggml_save_to_state_dict(*args, **kwargs)
        return super()._save_to_state_dict(*args, **kwargs)

    def ggml_save_to_state_dict(self, destination: dict[str, Tensor], prefix: str):
        # This is a fake state dict for vram estimation
        weight = zeros_like(self.weight, device=device("meta"))
        destination[prefix + "weight"] = weight
        if self.bias is not None:
            bias = zeros_like(self.bias, device=device("meta"))
            destination[prefix + "bias"] = bias
        return

    def get_weight(self, tensor: Optional[Tensor], dtype: dtype):
        if tensor is None:
            return

        # dequantize tensor while patches load
        weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)
        return weight

    def calc_size(self) -> int:
        """Get the size of this model in bytes."""
        return self.bias.nelement() * self.bias.element_size()

    def cast_bias_weight(
        self,
        input: Tensor,
        dtype: Optional[dtype] = None,
        device: Optional[device] = None,
        bias_dtype: Optional[dtype] = None,
    ) -> tuple[Tensor, Tensor]:
        if dtype is None:
            dtype = getattr(input, "dtype", float32)
            if dtype is None:
                raise ValueError("dtype is required")
        if bias_dtype is None:
            bias_dtype = dtype
        if device is None:
            device = input.device

        bias = self.get_weight(self.bias.to(device), dtype)
        if bias is not None:
            bias = bias.to(dtype=bias_dtype, device=device, copy=False)

        weight = self.get_weight(self.weight.to(device), dtype)
        if weight is not None:
            weight = weight.to(dtype=dtype, device=device)
        if weight is None or bias is None:
            raise ValueError("Weight or bias is None")
        return weight, bias
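GGUFTensor above keeps the packed GGUF bytes in the tensor itself and carries the quantization type and logical shape as extra Python attributes, re-attaching them in to() and __deepcopy__ because ordinary tensor operations drop them. A stripped-down sketch of that subclassing pattern (class and attribute names here are illustrative only, not part of the commit):

import torch


class TaggedTensor(torch.Tensor):
    def __new__(cls, data, *, tensor_type, tensor_shape):
        return super().__new__(cls, data)

    def __init__(self, data, *, tensor_type, tensor_shape):
        super().__init__()
        self.tensor_type = tensor_type      # e.g. a gguf.GGMLQuantizationType value
        self.tensor_shape = tensor_shape    # logical shape of the unpacked weight


packed = torch.zeros(8)  # stand-in for raw quantized bytes
t = TaggedTensor(packed, tensor_type="Q8_0", tensor_shape=torch.Size([2, 4]))
print(t.size(), t.tensor_shape)  # physical storage shape vs. logical shape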
invokeai/backend/quantization/gguf/loaders.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from pathlib import Path

import gguf
import torch

from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.utils import detect_arch


def gguf_sd_loader(
    path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
) -> dict[str, GGUFTensor]:
    """
    Read state dict as fake tensors
    """
    reader = gguf.GGUFReader(path)

    prefix_len = len(handle_prefix)
    tensor_names = {tensor.name for tensor in reader.tensors}
    has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)

    tensors: list[tuple[str, gguf.ReaderTensor]] = []
    for tensor in reader.tensors:
        sd_key = tensor_name = tensor.name
        if has_prefix:
            if not tensor_name.startswith(handle_prefix):
                continue
            sd_key = tensor_name[prefix_len:]
        tensors.append((sd_key, tensor))

    # detect and verify architecture
    compat = None
    arch_str = None
    arch_field = reader.get_field("general.architecture")
    if arch_field is not None:
        if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
            raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
        arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
        if arch_str not in {"flux"}:
            raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
    else:
        arch_str = detect_arch({val[0] for val in tensors})
        compat = "sd.cpp"

    # main loading loop
    state_dict: dict[str, GGUFTensor] = {}
    qtype_dict: dict[str, int] = {}
    for sd_key, tensor in tensors:
        tensor_name = tensor.name
        tensor_type_str = str(tensor.tensor_type)
        torch_tensor = torch.from_numpy(tensor.data)  # mmap

        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
        # Workaround for stable-diffusion.cpp SDXL detection.
        if compat == "sd.cpp" and arch_str == "sdxl":
            if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
                while len(shape) > 2 and shape[-1] == 1:
                    shape = shape[:-1]

        # add to state dict
        if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
            torch_tensor = torch_tensor.view(*shape)
        state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1

    return state_dict
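A small usage sketch for gguf_sd_loader; the file name below is hypothetical, but any FLUX transformer GGUF should behave the same way:

from pathlib import Path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

sd = gguf_sd_loader(Path("flux1-dev-Q4_K_S.gguf"))  # hypothetical local file
key, tensor = next(iter(sd.items()))
# Each value is a GGUFTensor: raw packed data plus its quant type and logical shape.
print(key, tensor.tensor_type, tuple(tensor.tensor_shape))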
invokeai/backend/quantization/gguf/torch_patcher.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from contextlib import contextmanager
from typing import Generator, Optional

import wrapt
from torch import Tensor, bfloat16, dtype, float16, nn

from invokeai.backend.quantization.gguf.layers import GGUFLayer


class TorchPatcher:
    @classmethod
    @contextmanager
    def wrap(cls) -> Generator[None, None, None]:
        # Dictionary to store original torch.nn classes for later restoration
        original_classes = {}
        try:
            # Iterate over _patcher's attributes and replace matching torch.nn classes
            for attr_name in dir(cls):
                if attr_name.startswith("__"):
                    continue
                # Get the class from _patcher
                patcher_class = getattr(cls, attr_name)

                # Check if torch.nn has a class with the same name
                if hasattr(nn, attr_name):
                    # Get the original torch.nn class
                    original_class = getattr(nn, attr_name)

                    # Define a helper function to bind the current patcher_attr for each iteration
                    def create_patch_function(patcher_attr):
                        # Return a new patch_class function specific to this patcher_attr
                        @wrapt.decorator
                        def patch_class(wrapped, instance, args, kwargs):
                            # Call the _patcher version of the class
                            return patcher_attr(*args, **kwargs)

                        return patch_class

                    # Save the original class for restoration later
                    original_classes[attr_name] = original_class

                    # Apply the patch
                    setattr(nn, attr_name, create_patch_function(patcher_class)(original_class))
            yield
        finally:
            # Restore the original torch.nn classes
            for attr_name, original_class in original_classes.items():
                setattr(nn, attr_name, original_class)


class GGUFPatcher(TorchPatcher):
    """
    Dequantize weights on the fly before doing the compute
    """

    class Linear(GGUFLayer, nn.Linear):
        def forward(self, input: Tensor) -> Tensor:
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.linear(input, weight, bias)

    class Conv2d(GGUFLayer, nn.Conv2d):
        def forward(self, input: Tensor) -> Tensor:
            weight, bias = self.cast_bias_weight(input)
            return self._conv_forward(input, weight, bias)

    class Embedding(GGUFLayer, nn.Embedding):
        def forward(self, input: Tensor, out_dtype: Optional[dtype] = None) -> Tensor:
            output_dtype = out_dtype
            if self.weight is None:
                raise ValueError("Embedding layer must have a weight")
            if self.weight.dtype == float16 or self.weight.dtype == bfloat16:
                out_dtype = None
            weight, _ = self.cast_bias_weight(input, device=input.device, dtype=out_dtype)
            return nn.functional.embedding(
                input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
            ).to(dtype=output_dtype)

    class LayerNorm(GGUFLayer, nn.LayerNorm):
        def forward(self, input: Tensor) -> Tensor:
            if self.weight is None:
                return nn.functional.layer_norm(input, self.normalized_shape, None, None, self.eps)
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

    class GroupNorm(GGUFLayer, nn.GroupNorm):
        def forward(self, input: Tensor) -> Tensor:
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
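TorchPatcher.wrap() temporarily swaps the torch.nn classes for the subclasses declared on the patcher, so any model constructed inside the context picks up the dequantize-on-the-fly layers, and the originals are restored on exit. The same idea in miniature (LoggingLinear is a made-up stand-in for GGUFPatcher.Linear):

from contextlib import contextmanager

from torch import nn


class LoggingLinear(nn.Linear):
    def forward(self, x):
        # A real patcher would dequantize self.weight here before the matmul.
        return super().forward(x)


@contextmanager
def patch_linear():
    original = nn.Linear
    nn.Linear = LoggingLinear  # modules built inside the block use the patched class
    try:
        yield
    finally:
        nn.Linear = original  # always restore torch.nn on exit


with patch_linear():
    model = nn.Sequential(nn.Linear(4, 4))

assert isinstance(model[0], LoggingLinear)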
invokeai/backend/quantization/gguf/utils.py (new file, 361 lines)
@@ -0,0 +1,361 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from typing import Callable, Optional, Union

import gguf
import torch

TORCH_COMPATIBLE_QTYPES = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

# K Quants #
QK_K = 256
K_SCALE_SIZE = 12

MODEL_DETECTION = (
    (
        "flux",
        (
            ("transformer_blocks.0.attn.norm_added_k.weight",),
            ("double_blocks.0.img_attn.proj.weight",),
        ),
    ),
)


def get_scale_min(scales: torch.Tensor):
    n_blocks = scales.shape[0]
    scales = scales.view(torch.uint8)
    scales = scales.reshape((n_blocks, 3, 4))

    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
    min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

    return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))


# Legacy Quants #
def dequantize_blocks_Q8_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    d, x = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)
    x = x.view(torch.int8)
    return d * x


def dequantize_blocks_Q5_1(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape((n_blocks, -1))

    qs = ql | (qh << 4)
    return (d * qs) + m


def dequantize_blocks_Q5_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return d * qs


def dequantize_blocks_Q4_1(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, m, qs = split_block_dims(blocks, 2, 2)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)
    qs = (qs & 0x0F).reshape(n_blocks, -1)

    return (d * qs) + m


def dequantize_blocks_Q4_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return d * qs

def dequantize_blocks_BF16(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)


def dequantize_blocks_Q6_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    (
        ql,
        qh,
        scales,
        d,
    ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)

    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

    ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))

    return (d * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q5_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)

    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    q = ql | (qh << 4)

    return (d * q - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q4_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

    return (d * qs - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q3_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
    d = d.view(torch.float16).to(dtype)

    lscales, hscales = scales[:, :8], scales[:, 8:]
    lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 2, 1)
    )
    lscales = lscales.reshape((n_blocks, 16))
    hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
        [0, 2, 4, 6], device=d.device, dtype=torch.uint8
    ).reshape((1, 4, 1))
    hscales = hscales.reshape((n_blocks, 16))
    scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
    scales = scales.to(torch.int8) - 32

    dl = (d * scales).reshape((n_blocks, 16, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
    qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
    q = ql.to(torch.int8) - (qh << 2).to(torch.int8)

    return (dl * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q2_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    # (n_blocks, 16, 1)
    dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
    ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))

    shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))

    qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
    qs = qs.reshape((n_blocks, QK_K // 16, 16))
    qs = dl * qs - ml

    return qs.reshape((n_blocks, -1))

DEQUANTIZE_FUNCTIONS: dict[
    gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor]
] = {
    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}


def is_torch_compatible(tensor: Optional[torch.Tensor]):
    return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES


def is_quantized(tensor: torch.Tensor):
    return not is_torch_compatible(tensor)


def dequantize_tensor(
    tensor: torch.Tensor, dtype: torch.dtype, dequant_dtype: Union[torch.dtype, str, None] = None
) -> torch.Tensor:
    qtype: Optional[gguf.GGMLQuantizationType] = getattr(tensor, "tensor_type", None)
    oshape: torch.Size = getattr(tensor, "tensor_shape", tensor.shape)
    if qtype is None:
        raise ValueError("This is not a valid quantized tensor")
    if qtype in TORCH_COMPATIBLE_QTYPES:
        return tensor.to(dtype)
    elif qtype in DEQUANTIZE_FUNCTIONS:
        dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
        return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
    else:
        new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(tensor.device, dtype=dtype)


def dequantize(
    data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None
):
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = DEQUANTIZE_FUNCTIONS[qtype]

    rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
    return blocks.reshape(oshape)


def to_uint32(x: torch.Tensor) -> torch.Tensor:
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)


def split_block_dims(blocks: torch.Tensor, *args):
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)


PATCH_TYPES = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]


def move_patch_to_device(item: PATCH_TYPES, device: torch.device) -> PATCH_TYPES:
    if isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    elif isinstance(item, tuple):
        if len(item) == 0:
            return item
        if not isinstance(item[0], torch.Tensor):
            raise ValueError("Invalid item")
        return tuple(move_patch_to_device(x, device) for x in item)
    elif isinstance(item, list):
        if len(item) == 0:
            return item
        if not isinstance(item[0], torch.Tensor):
            raise ValueError("Invalid item")
        return [move_patch_to_device(x, device) for x in item]


def detect_arch(state_dict: dict[str, torch.Tensor]):
    for arch, match_lists in MODEL_DETECTION:
        for match_list in match_lists:
            if all(key in state_dict for key in match_list):
                return arch
    raise ValueError("Unknown model architecture!")
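To make the block layout concrete, here is a worked example of the simplest case, Q8_0, matching what dequantize_blocks_Q8_0 computes: each 34-byte block is a float16 scale followed by 32 int8 values, and the dequantized weights are just scale * value (block and type sizes as reported by gguf's GGML_QUANT_SIZES):

import torch

block_size, type_size = 32, 34  # 32 weights per block, 2-byte scale + 32 int8 bytes

d = torch.tensor([0.5], dtype=torch.float16)          # per-block scale
q = torch.arange(-16, 16, dtype=torch.int8)           # 32 quantized values
raw = torch.cat([d.view(torch.uint8), q.view(torch.uint8)])  # one packed 34-byte block
assert raw.numel() == type_size

blocks = raw.reshape(1, type_size)
scale, x = torch.split(blocks, [2, 32], dim=1)        # equivalent to split_block_dims(blocks, 2)
out = scale.view(torch.float16).to(torch.float32) * x.view(torch.int8)
assert torch.allclose(out, 0.5 * q.to(torch.float32))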