Add workaround for FLUX GGUF models with incorrect img_in.weight shape.

Ryan Dick authored on 2024-10-01 21:08:12 +00:00
committed by Kent Keirsey
parent ec7e771942
commit bc63e2acc5
2 changed files with 20 additions and 9 deletions


@@ -37,6 +37,7 @@ from invokeai.backend.model_manager.util.model_util import (
     convert_bundle_to_flux_transformer_checkpoint,
 )
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
+from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 try:
@@ -234,11 +235,21 @@ class FluxGGUFCheckpointModel(ModelLoader):
         model_path = Path(config.path)
 
         with SilenceWarnings():
-            # Load the state dict and patcher
-            model = Flux(params[config.config_path])
             # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
             sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
-            model.load_state_dict(sd, assign=True)
+
+            model = Flux(params[config.config_path])
+
+            # HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight.
+            # We override the shape here to fix the issue.
+            # Example model with this issue (Q4_K_M): https://civitai.com/models/705823/ggufk-flux-unchained-km-quants
+            img_in_weight = sd.get("img_in.weight", None)
+            if img_in_weight is not None and img_in_weight._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
+                expected_img_in_weight_shape = model.img_in.weight.shape
+                img_in_weight.quantized_data = img_in_weight.quantized_data.view(expected_img_in_weight_shape)
+                img_in_weight.tensor_shape = expected_img_in_weight_shape
+
+            model.load_state_dict(sd, assign=True)
 
         return model
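For readers outside the diff context, the core of the workaround can be illustrated standalone. This is a minimal sketch, not the InvokeAI API: FakeGGMLTensor and fix_img_in_weight are hypothetical stand-ins for GGMLTensor and the inline fix above, and the (3072, 64) shape assumes FLUX's img_in linear layer. For torch-compatible quantization types the raw payload is a plain tensor with the same element count as the logical weight, so .view() can relabel its shape without moving any bytes:

import torch
from dataclasses import dataclass

@dataclass
class FakeGGMLTensor:  # hypothetical stand-in for GGMLTensor
    quantized_data: torch.Tensor  # raw payload (a plain tensor for torch-compatible qtypes)
    tensor_shape: torch.Size      # logical shape reported to math ops

def fix_img_in_weight(sd: dict, expected: torch.Size) -> None:
    w = sd.get("img_in.weight", None)
    if w is not None and w.quantized_data.numel() == expected.numel():
        # .view() is a metadata-only reshape: no bytes are copied or moved.
        w.quantized_data = w.quantized_data.view(expected)
        w.tensor_shape = expected

sd = {"img_in.weight": FakeGGMLTensor(torch.zeros(3072 * 64), torch.Size([3072 * 64]))}
fix_img_in_weight(sd, torch.Size([3072, 64]))
assert sd["img_in.weight"].tensor_shape == (3072, 64)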


@@ -40,7 +40,7 @@ def apply_to_quantized_tensor(func, args, kwargs):
         raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
     return GGMLTensor(
-        new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape, ggml_tensor.compute_dtype
+        new_data, ggml_tensor._ggml_quantization_type, ggml_tensor.tensor_shape, ggml_tensor.compute_dtype
     )
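The hunk above is one instance of an "apply to payload, rewrap" pattern: an op runs on the raw quantized payload, and the result is rewrapped so the metadata (including the logical shape) survives the op. A minimal sketch under assumed names (ShapedTensor and apply_to_payload are hypothetical; the real GGMLTensor also carries the GGML quantization type and compute dtype):

import torch

class ShapedTensor:
    """Hypothetical stand-in: a raw payload plus the logical shape it dequantizes to."""
    def __init__(self, data: torch.Tensor, tensor_shape: torch.Size):
        self.quantized_data = data
        self.tensor_shape = tensor_shape

def apply_to_payload(func, t: ShapedTensor, *args, **kwargs) -> ShapedTensor:
    new_data = func(t.quantized_data, *args, **kwargs)
    if new_data.dtype != t.quantized_data.dtype:
        # Ops must not silently change the packed dtype of the payload.
        raise ValueError("Operation changed the dtype unexpectedly.")
    # Rewrap so the logical (dequantized) shape metadata survives the op.
    return ShapedTensor(new_data, t.tensor_shape)

t = ShapedTensor(torch.zeros(8, dtype=torch.uint8), torch.Size([16]))
moved = apply_to_payload(torch.Tensor.to, t, "cpu")  # e.g. a device move
assert moved.tensor_shape == (16,)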
@@ -91,11 +91,11 @@ class GGMLTensor(torch.Tensor):
         self.quantized_data = data
         self._ggml_quantization_type = ggml_quantization_type
         # The dequantized shape of the tensor.
-        self._tensor_shape = tensor_shape
+        self.tensor_shape = tensor_shape
         self.compute_dtype = compute_dtype
 
     def __repr__(self, *, tensor_contents=None):
-        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self._tensor_shape})"
+        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self.tensor_shape})"
 
     @overload
     def size(self, dim: None = None) -> torch.Size: ...
@@ -106,8 +106,8 @@ class GGMLTensor(torch.Tensor):
     def size(self, dim: int | None = None):
         """Return the size of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
         if dim is not None:
-            return self._tensor_shape[dim]
-        return self._tensor_shape
+            return self.tensor_shape[dim]
+        return self.tensor_shape
 
     @property
     def shape(self) -> torch.Size:  # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason.
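The size()/shape overrides above implement a shape-reporting pattern: the wrapper advertises its dequantized (logical) shape rather than the packed payload's shape, so downstream math ops see the shape they expect. A simplified sketch with hypothetical names (the real GGMLTensor subclasses torch.Tensor; this sketch does not):

import torch

class PackedTensor:
    def __init__(self, data: torch.Tensor, tensor_shape: torch.Size):
        self.quantized_data = data        # packed payload, e.g. shape (n_bytes,)
        self.tensor_shape = tensor_shape  # logical shape after dequantization

    def size(self, dim: int | None = None):
        # Math ops should see the dequantized shape, not the packed one.
        return self.tensor_shape if dim is None else self.tensor_shape[dim]

    @property
    def shape(self) -> torch.Size:
        return self.size()

p = PackedTensor(torch.zeros(36, dtype=torch.uint8), torch.Size([4, 16]))
assert p.shape == (4, 16) and p.size(0) == 4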
@@ -136,7 +136,7 @@ class GGMLTensor(torch.Tensor):
         elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
             # TODO(ryand): Look into how the dtype param is intended to be used.
             return dequantize(
-                data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self._tensor_shape, dtype=None
+                data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self.tensor_shape, dtype=None
             ).to(self.compute_dtype)
         else:
             # There is no GPU implementation for this quantization type, so fallback to the numpy implementation.
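The dequantize branch above follows a dispatch-with-fallback pattern: use a fast torch implementation when one is registered for the quantization type, otherwise fall back to a slower reference path. A hedged sketch, with DEQUANT_KERNELS and dequantize_any as hypothetical stand-ins for DEQUANTIZE_FUNCTIONS and the surrounding method:

import torch

DEQUANT_KERNELS = {
    # F16 "dequantization" is just a byte reinterpretation plus a reshape.
    "F16": lambda data, oshape: data.view(torch.float16).reshape(oshape),
}

def dequantize_any(data: torch.Tensor, qtype: str, oshape, compute_dtype: torch.dtype) -> torch.Tensor:
    kernel = DEQUANT_KERNELS.get(qtype)
    if kernel is not None:
        # Fast path: a torch implementation exists for this quantization type.
        return kernel(data, oshape).to(compute_dtype)
    # Slow path: no torch kernel for this qtype; a real implementation would
    # fall back to a reference (e.g. numpy-based) dequantizer here.
    raise NotImplementedError(f"no torch kernel for {qtype}")

raw = torch.arange(8, dtype=torch.float16).view(torch.uint8)  # packed bytes
out = dequantize_any(raw, "F16", (2, 4), torch.float32)
assert out.shape == (2, 4) and out.dtype == torch.float32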