Add a compute_dtype field to GGMLTensor.
@@ -235,7 +235,8 @@ class FluxGGUFCheckpointModel(ModelLoader):
         with SilenceWarnings():
             # Load the state dict and patcher
-            sd = gguf_sd_loader(model_path)
+            # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
+            sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
             model = Flux(params[config.config_path])
             model.load_state_dict(sd, assign=True)
         return model
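Note: the new parameter threads a single compute dtype from the call site down to every GGMLTensor in the returned state dict. A minimal usage sketch; the checkpoint filename and the module path for gguf_sd_loader are assumptions for illustration:

    import torch
    from pathlib import Path

    # Assumed import path; the diff does not name the module that defines gguf_sd_loader.
    from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

    # Every tensor in `sd` will dequantize to bfloat16 when used in an op.
    sd = gguf_sd_loader(Path("flux1-dev-Q8_0.gguf"), compute_dtype=torch.bfloat16)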
@@ -412,7 +412,7 @@ class ModelProbe(object):
             assert isinstance(model, dict)
             return model
         elif model_path.suffix.endswith(".gguf"):
-            return gguf_sd_loader(model_path)
+            return gguf_sd_loader(model_path, compute_dtype=torch.float32)
         else:
             return safetensors.torch.load_file(model_path)
@@ -57,7 +57,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
     if scan_result.infected_files != 0:
         raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
     if str(path).endswith(".gguf"):
-        checkpoint = gguf_sd_loader(Path(path))
+        checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
     else:
         checkpoint = torch.load(path, map_location=torch.device("meta"))
     return checkpoint
@@ -9,43 +9,15 @@ from invokeai.backend.quantization.gguf.utils import (
     dequantize,
 )
 
-# Ranking of precision preference for different dtypes.
-# When applying an operation involving a GGMLTensor and other non-GGMLTensors, we will run the operation at the
-# highest precision of the non-GGMLTensors.
-DTYPE_PRECISION_RANK = {
-    torch.float64: 0,
-    torch.float32: 1,
-    torch.bfloat16: 2,  # Note: We prefer bfloat16 over float16 for our typical use cases.
-    torch.float16: 3,
-    torch.float8_e4m3fn: 4,
-}
-
-
-def choose_highest_precision_dtype(dtypes: list[torch.dtype]) -> torch.dtype:
-    if len(dtypes) == 0:
-        # TODO(ryand): If we ever hit this case, there's a good chance we'll want to allow the user to specify the
-        # desired compute dtype.
-        return torch.float32
-    return min(dtypes, key=lambda dtype: DTYPE_PRECISION_RANK[dtype])
-
-
 def dequantize_and_run(func, args, kwargs):
     """A helper function for running math ops on GGMLTensor inputs.
 
     Dequantizes the inputs, and runs the function.
     """
-    # Determine which precision to run the operation at.
-    all_input_dtypes = [a.dtype for a in args if type(a) is torch.Tensor] + [
-        v.dtype for v in kwargs.values() if type(v) is torch.Tensor
-    ]
-    compute_dtype = choose_highest_precision_dtype(all_input_dtypes)
-
-    dequantized_args = [
-        a.get_dequantized_tensor(dtype=compute_dtype) if hasattr(a, "get_dequantized_tensor") else a for a in args
-    ]
+    dequantized_args = [a.get_dequantized_tensor() if hasattr(a, "get_dequantized_tensor") else a for a in args]
     dequantized_kwargs = {
-        k: v.get_dequantized_tensor(dtype=compute_dtype) if hasattr(v, "get_dequantized_tensor") else v
-        for k, v in kwargs.items()
+        k: v.get_dequantized_tensor() if hasattr(v, "get_dequantized_tensor") else v for k, v in kwargs.items()
     }
     return func(*dequantized_args, **dequantized_kwargs)
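Note: the removed helper picked the most precise dtype among the non-quantized operands; because higher precision maps to a lower rank, `min` over the table returns it. A small worked example of that removed behavior, kept for reference:

    import torch

    # Reproduces the removed choose_highest_precision_dtype logic for illustration.
    DTYPE_PRECISION_RANK = {torch.float64: 0, torch.float32: 1, torch.bfloat16: 2, torch.float16: 3}
    dtypes = [torch.float16, torch.bfloat16]
    assert min(dtypes, key=lambda d: DTYPE_PRECISION_RANK[d]) == torch.bfloat16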
@@ -67,7 +39,9 @@ def apply_to_quantized_tensor(func, args, kwargs):
         # This is intended to catch calls such as `.to(dtype=torch.float32)`, which are not supported on GGMLTensors.
         raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
 
-    return GGMLTensor(new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape)
+    return GGMLTensor(
+        new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape, ggml_tensor.compute_dtype
+    )
 
 
 GGML_TENSOR_OP_TABLE = {
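Note: apply_to_quantized_tensor covers ops that act on the quantized payload itself (e.g. device moves), so the wrapper must now carry compute_dtype over to the result. A behavior sketch, assuming `t` is a GGMLTensor and a CUDA device is available:

    t_gpu = t.to(device="cuda")  # stays quantized; per this hunk, compute_dtype is carried over
    assert t_gpu.compute_dtype == t.compute_dtype
    t.to(dtype=torch.float32)    # raises ValueError, per the guard above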
@@ -89,7 +63,13 @@ class GGMLTensor(torch.Tensor):
     """
 
     @staticmethod
-    def __new__(cls, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        ggml_quantization_type: gguf.GGMLQuantizationType,
+        tensor_shape: torch.Size,
+        compute_dtype: torch.dtype,
+    ):
         # Type hinting is not supported for torch.Tensor._make_wrapper_subclass, so we ignore the errors.
         return torch.Tensor._make_wrapper_subclass(  # pyright: ignore
             cls,
@@ -101,11 +81,18 @@ class GGMLTensor(torch.Tensor):
             storage_offset=data.storage_offset(),
         )
 
-    def __init__(self, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
+    def __init__(
+        self,
+        data: torch.Tensor,
+        ggml_quantization_type: gguf.GGMLQuantizationType,
+        tensor_shape: torch.Size,
+        compute_dtype: torch.dtype,
+    ):
         self.quantized_data = data
         self._ggml_quantization_type = ggml_quantization_type
         # The dequantized shape of the tensor.
         self._tensor_shape = tensor_shape
+        self.compute_dtype = compute_dtype
 
     def __repr__(self, *, tensor_contents=None):
         return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self._tensor_shape})"
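Note: with compute_dtype stored in __init__, construction looks as below. F16 data is used because it needs no GGML dequantization, assuming F16 is among TORCH_COMPATIBLE_QTYPES as the loader's view() branch suggests:

    import gguf
    import torch

    data = torch.randn(32, 64, dtype=torch.float16)
    t = GGMLTensor(
        data=data,
        ggml_quantization_type=gguf.GGMLQuantizationType.F16,
        tensor_shape=data.shape,
        compute_dtype=torch.bfloat16,  # stored on the instance instead of passed per call
    )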
@@ -138,23 +125,23 @@ class GGMLTensor(torch.Tensor):
         """
         return self
 
-    def get_dequantized_tensor(self, dtype: torch.dtype):
+    def get_dequantized_tensor(self):
         """Return the dequantized tensor.
 
         Args:
             dtype: The dtype of the dequantized tensor.
         """
         if self._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
-            return self.quantized_data.to(dtype)
+            return self.quantized_data.to(self.compute_dtype)
         elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
             # TODO(ryand): Look into how the dtype param is intended to be used.
             return dequantize(
                 data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self._tensor_shape, dtype=None
-            ).to(dtype)
+            ).to(self.compute_dtype)
         else:
             # There is no GPU implementation for this quantization type, so fallback to the numpy implementation.
             new = gguf.quants.dequantize(self.quantized_data.cpu().numpy(), self._ggml_quantization_type)
-            return torch.from_numpy(new).to(self.quantized_data.device, dtype=dtype)
+            return torch.from_numpy(new).to(self.quantized_data.device, dtype=self.compute_dtype)
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
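Note: with the dtype parameter gone, all three dequantization branches land on the stored compute_dtype. Continuing the construction sketch above:

    w = t.get_dequantized_tensor()     # no dtype argument anymore
    assert w.dtype == t.compute_dtype  # bfloat16, whichever branch was taken
    assert w.shape == torch.Size([32, 64])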
@@ -7,7 +7,7 @@ from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
 
 
-def gguf_sd_loader(path: Path) -> dict[str, GGMLTensor]:
+def gguf_sd_loader(path: Path, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
     reader = gguf.GGUFReader(path)
 
     sd: dict[str, GGMLTensor] = {}
@@ -16,5 +16,7 @@ def gguf_sd_loader(path: Path) -> dict[str, GGMLTensor]:
         shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
         if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
             torch_tensor = torch_tensor.view(*shape)
-        sd[tensor.name] = GGMLTensor(torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape)
+        sd[tensor.name] = GGMLTensor(
+            torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape, compute_dtype=compute_dtype
+        )
     return sd
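Note: a quick sanity-check sketch of the loader's new contract, reusing the imports from the earlier sketch; the filename is illustrative:

    sd = gguf_sd_loader(Path("model-Q4_K.gguf"), compute_dtype=torch.float32)
    assert all(t.compute_dtype == torch.float32 for t in sd.values())
    # The dict can then be assigned straight into a module, as in the Flux hunk:
    # model.load_state_dict(sd, assign=True)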
@@ -2,7 +2,7 @@ import gguf
 import pytest
 import torch
 
-from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor, choose_highest_precision_dtype
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 
 
 def quantize_tensor(data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType) -> GGMLTensor:
@@ -13,7 +13,10 @@ def quantize_tensor(data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantiz
     data_np = data.detach().cpu().numpy()
     quantized_np = gguf.quantize(data_np, ggml_quantization_type)
     return GGMLTensor(
-        data=torch.from_numpy(quantized_np), ggml_quantization_type=ggml_quantization_type, tensor_shape=data.shape
+        data=torch.from_numpy(quantized_np),
+        ggml_quantization_type=ggml_quantization_type,
+        tensor_shape=data.shape,
+        compute_dtype=data.dtype,
     ).to(device=data.device)  # type: ignore
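Note: the updated helper sets compute_dtype from the input's dtype, which makes round-trip tests straightforward. A usage sketch; the Q8_0 tolerance is a loose assumption for illustration:

    x = torch.randn(32, 64)
    x_q = quantize_tensor(x, gguf.GGMLQuantizationType.Q8_0)
    x_dq = x_q.get_dequantized_tensor()
    assert x_dq.dtype == x.dtype               # compute_dtype came from data.dtype
    assert torch.allclose(x, x_dq, atol=0.1)   # Q8_0 is lossy; tolerance assumed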
@@ -95,18 +98,6 @@ def test_ggml_tensor_to_device():
     assert torch.allclose(x_cpu.quantized_data, x_gpu.quantized_data.cpu(), atol=1e-5)
 
 
-@pytest.mark.parametrize(
-    ["dtypes", "expected_dtype"],
-    [
-        ([], torch.float32),  # Default to float32 if no dtypes are provided.
-        ([torch.float32, torch.float16, torch.bfloat16], torch.float32),
-        ([torch.float16, torch.bfloat16], torch.bfloat16),
-    ],
-)
-def test_choose_highest_precision_dtype(dtypes: list[torch.dtype], expected_dtype: torch.dtype):
-    assert choose_highest_precision_dtype(dtypes) == expected_dtype
-
-
 def test_ggml_tensor_shape():
     x = torch.randn(32, 64)
     x_quantized = quantize_tensor(x, gguf.GGMLQuantizationType.Q8_0)