mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-01-14 19:38:08 -05:00
Add workaround for FLUX GGUF models with incorrect img_in.weight shape.
@@ -37,6 +37,7 @@ from invokeai.backend.model_manager.util.model_util import (
     convert_bundle_to_flux_transformer_checkpoint,
 )
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
+from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 try:
@@ -234,11 +235,21 @@ class FluxGGUFCheckpointModel(ModelLoader):
         model_path = Path(config.path)
 
         with SilenceWarnings():
             # Load the state dict and patcher
+            model = Flux(params[config.config_path])
             # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
             sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
-            model = Flux(params[config.config_path])
-            model.load_state_dict(sd, assign=True)
+
+            # HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight.
+            # We override the shape here to fix the issue.
+            # Example model with this issue (Q4_K_M): https://civitai.com/models/705823/ggufk-flux-unchained-km-quants
+            img_in_weight = sd.get("img_in.weight", None)
+            if img_in_weight is not None and img_in_weight._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
+                expected_img_in_weight_shape = model.img_in.weight.shape
+                img_in_weight.quantized_data = img_in_weight.quantized_data.view(expected_img_in_weight_shape)
+                img_in_weight.tensor_shape = expected_img_in_weight_shape
+
+            model.load_state_dict(sd, assign=True)
         return model
 
 
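Why this works: torch's view() only reinterprets shape metadata over the same contiguous storage, so when the element counts match, the reshape corrects the checkpoint's wrong shape without touching the data. The TORCH_COMPATIBLE_QTYPES guard reads as restricting the fix to formats whose raw storage holds one stored element per logical element (plain float payloads); for block-quantized types the raw blob's shape differs from the dequantized shape and such a view would not line up. A minimal sketch with illustrative shapes (not the real FLUX img_in dimensions):

import torch

# The checkpoint stored the weight under the wrong shape, e.g. flattened.
stored = torch.arange(24, dtype=torch.bfloat16)  # wrong shape: (24,)
expected_shape = torch.Size([4, 6])              # shape the model's img_in layer expects

# view() reinterprets the same storage under the expected shape; nothing is
# copied or reordered, so the numeric content of the weight is unchanged.
fixed = stored.view(expected_shape)
assert fixed.shape == expected_shape
assert fixed.data_ptr() == stored.data_ptr()  # same underlying memory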
@@ -40,7 +40,7 @@ def apply_to_quantized_tensor(func, args, kwargs):
         raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
 
     return GGMLTensor(
-        new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape, ggml_tensor.compute_dtype
+        new_data, ggml_tensor._ggml_quantization_type, ggml_tensor.tensor_shape, ggml_tensor.compute_dtype
     )
 
 
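For orientation: this helper applies a torch op to the raw quantized storage and re-wraps the result, carrying the quantization metadata over unchanged; the change here is only that the dequantized shape is now read from the public tensor_shape attribute. A rough, self-contained sketch of the re-wrapping pattern (Wrapped is a hypothetical stand-in, not the real GGMLTensor):

import torch

class Wrapped:
    """Hypothetical stand-in for GGMLTensor, for illustration only."""
    def __init__(self, data: torch.Tensor, qtype: str, tensor_shape: torch.Size):
        self.quantized_data = data        # raw (possibly packed) storage
        self.qtype = qtype                # quantization format tag
        self.tensor_shape = tensor_shape  # dequantized shape

def apply_to_raw(func, w: Wrapped) -> Wrapped:
    new_data = func(w.quantized_data)
    if new_data.dtype != w.quantized_data.dtype:
        raise ValueError("Operation changed the dtype unexpectedly.")
    # The op only touches storage (e.g. a device move); metadata is carried over.
    return Wrapped(new_data, w.qtype, w.tensor_shape)

w = Wrapped(torch.zeros(8, dtype=torch.uint8), "Q4_K", torch.Size([4, 4]))
moved = apply_to_raw(lambda t: t.clone(), w)
assert moved.tensor_shape == w.tensor_shape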
@@ -91,11 +91,11 @@ class GGMLTensor(torch.Tensor):
         self.quantized_data = data
         self._ggml_quantization_type = ggml_quantization_type
         # The dequantized shape of the tensor.
-        self._tensor_shape = tensor_shape
+        self.tensor_shape = tensor_shape
         self.compute_dtype = compute_dtype
 
     def __repr__(self, *, tensor_contents=None):
-        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self._tensor_shape})"
+        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self.tensor_shape})"
 
     @overload
     def size(self, dim: None = None) -> torch.Size: ...
@@ -106,8 +106,8 @@ class GGMLTensor(torch.Tensor):
     def size(self, dim: int | None = None):
         """Return the size of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
         if dim is not None:
-            return self._tensor_shape[dim]
-        return self._tensor_shape
+            return self.tensor_shape[dim]
+        return self.tensor_shape
 
     @property
     def shape(self) -> torch.Size:  # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason.
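Reporting the dequantized shape from size() (and shape) is what lets a quantized parameter pass the shape checks in load_state_dict and behave correctly in math ops while its storage stays packed. A self-contained sketch of the two-shapes idea (plain class for clarity; the real GGMLTensor subclasses torch.Tensor):

import torch

class TwoShapes:
    """Packed storage in one shape; logical (dequantized) shape reported to callers."""
    def __init__(self, raw: torch.Tensor, logical: torch.Size):
        self.quantized_data = raw    # e.g. packed quantization blocks
        self.tensor_shape = logical  # shape after dequantization

    def size(self, dim: int | None = None):
        # Callers must always see the logical shape, never the packed one.
        return self.tensor_shape if dim is None else self.tensor_shape[dim]

t = TwoShapes(torch.empty(144, dtype=torch.uint8), torch.Size([16, 16]))
assert t.size() == torch.Size([16, 16])
assert t.size(0) == 16
assert t.quantized_data.shape != t.size()  # storage shape differs from logical shape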
@@ -136,7 +136,7 @@ class GGMLTensor(torch.Tensor):
         elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
             # TODO(ryand): Look into how the dtype param is intended to be used.
             return dequantize(
-                data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self._tensor_shape, dtype=None
+                data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self.tensor_shape, dtype=None
             ).to(self.compute_dtype)
         else:
             # There is no GPU implementation for this quantization type, so fallback to the numpy implementation.
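The branches around this hunk, as far as they can be read here, dispatch on quantization type: torch-compatible payloads only need a cast, types with a registered kernel go through dequantize(), and everything else falls back to the numpy implementation. A sketch of that dispatch with placeholder type sets (the real sets live in InvokeAI's gguf quantization module):

import torch

TORCH_COMPATIBLE = {"F16", "F32"}  # placeholder: formats stored as plain torch dtypes
KERNEL_BACKED = {"Q8_0", "Q4_K"}   # placeholder: formats with a dequantize kernel

def get_dequantized(data: torch.Tensor, qtype: str, oshape: torch.Size,
                    compute_dtype: torch.dtype) -> torch.Tensor:
    if qtype in TORCH_COMPATIBLE:
        # One stored element per logical element: the data already has the
        # logical layout, so casting to the compute dtype is enough.
        return data.to(compute_dtype)
    if qtype in KERNEL_BACKED:
        raise NotImplementedError("would call the registered dequantize kernel here")
    # No fast path for this format: fall back to a (slow) numpy implementation.
    raise NotImplementedError("would fall back to the numpy implementation here")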