From bc63e2acc5ea45db7daa7febe7238e4673e79fe1 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 1 Oct 2024 21:08:12 +0000
Subject: [PATCH] Add workaround for FLUX GGUF models with incorrect img_in.weight shape.
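
Some GGUF exports of FLUX store img_in.weight with an incorrect shape.
img_in is a Linear projection, so its weight should have shape
(out_features, in_features), i.e. (3072, 64) for the standard FLUX
transformer configs. When the tensor's quantization type is one of the
TORCH_COMPATIBLE_QTYPES (unquantized types whose payload is a plain
torch tensor with one element per weight), the shape can be repaired
with a zero-copy view(). A minimal sketch of the idea, using
illustrative shapes rather than values read from a real file:

    import torch

    expected_shape = torch.Size([3072, 64])
    # A broken export may store the weight flattened, e.g. as (1, 196608).
    # numel() matches the expected shape, so view() is safe and copy-free.
    stored = torch.zeros(1, 3072 * 64, dtype=torch.float16)
    fixed = stored.view(expected_shape)
    assert fixed.shape == expected_shape

For genuinely quantized types (e.g. Q4_K) the packed payload does not
have one element per weight, so the same view() would be invalid; the
workaround below is therefore gated on TORCH_COMPATIBLE_QTYPES.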
---
 .../model_manager/load/model_loaders/flux.py  | 17 ++++++++++++++---
 .../backend/quantization/gguf/ggml_tensor.py  | 12 ++++++------
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index cffa426a0a..c461f6b621 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -37,6 +37,7 @@ from invokeai.backend.model_manager.util.model_util import (
     convert_bundle_to_flux_transformer_checkpoint,
 )
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
+from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 try:
@@ -234,11 +235,21 @@ class FluxGGUFCheckpointModel(ModelLoader):
         model_path = Path(config.path)
 
         with SilenceWarnings():
-            # Load the state dict and patcher
+            model = Flux(params[config.config_path])
+
+            # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
             sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
-            model = Flux(params[config.config_path])
-            model.load_state_dict(sd, assign=True)
+
+            # HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight.
+            # We override the shape here to fix the issue.
+            # Example model with this issue (Q4_K_M): https://civitai.com/models/705823/ggufk-flux-unchained-km-quants
+            img_in_weight = sd.get("img_in.weight", None)
+            if img_in_weight is not None and img_in_weight._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
+                expected_img_in_weight_shape = model.img_in.weight.shape
+                img_in_weight.quantized_data = img_in_weight.quantized_data.view(expected_img_in_weight_shape)
+                img_in_weight.tensor_shape = expected_img_in_weight_shape
+
+            model.load_state_dict(sd, assign=True)
         return model
diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py
index eaddcb7ea0..b9504318c1 100644
--- a/invokeai/backend/quantization/gguf/ggml_tensor.py
+++ b/invokeai/backend/quantization/gguf/ggml_tensor.py
@@ -40,7 +40,7 @@ def apply_to_quantized_tensor(func, args, kwargs):
         raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
 
     return GGMLTensor(
-        new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape, ggml_tensor.compute_dtype
+        new_data, ggml_tensor._ggml_quantization_type, ggml_tensor.tensor_shape, ggml_tensor.compute_dtype
     )
@@ -91,11 +91,11 @@ class GGMLTensor(torch.Tensor):
         self.quantized_data = data
         self._ggml_quantization_type = ggml_quantization_type
         # The dequantized shape of the tensor.
-        self._tensor_shape = tensor_shape
+        self.tensor_shape = tensor_shape
         self.compute_dtype = compute_dtype
 
     def __repr__(self, *, tensor_contents=None):
-        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self._tensor_shape})"
+        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self.tensor_shape})"
 
     @overload
     def size(self, dim: None = None) -> torch.Size: ...
@@ -106,8 +106,8 @@ def size(self, dim: int | None = None):
         """Return the size of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
         if dim is not None:
-            return self._tensor_shape[dim]
-        return self._tensor_shape
+            return self.tensor_shape[dim]
+        return self.tensor_shape
 
     @property
     def shape(self) -> torch.Size:  # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason.
@@ -136,7 +136,7 @@ class GGMLTensor(torch.Tensor):
         elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
             # TODO(ryand): Look into how the dtype param is intended to be used.
             return dequantize(
-                data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self._tensor_shape, dtype=None
+                data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self.tensor_shape, dtype=None
             ).to(self.compute_dtype)
         else:
             # There is no GPU implementation for this quantization type, so fallback to the numpy implementation.
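
Note: the ggml_tensor.py half of this patch simply promotes the private
_tensor_shape attribute to a public tensor_shape so that loader code
(like the flux.py hunk above) can correct a mis-shaped tensor in place.
A sketch of how a caller can apply the same fix once this patch lands;
fix_tensor_shape is a hypothetical helper and the key/shape arguments
are illustrative:

    import torch

    from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
    from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES

    def fix_tensor_shape(sd: dict[str, GGMLTensor], key: str, expected: torch.Size) -> None:
        """Repair a mis-shaped GGMLTensor in place, when it is safe to do so."""
        t = sd.get(key, None)
        # Only torch-compatible (unquantized) payloads can be view()ed directly.
        if t is not None and t._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
            t.quantized_data = t.quantized_data.view(expected)
            t.tensor_shape = expected  # public attribute after this patch

    # Usage: fix_tensor_shape(sd, "img_in.weight", torch.Size([3072, 64]))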