Add a compute_dtype field to GGMLTensor.

commit ec7e771942
parent fe84013392
Author: Ryan Dick
Date: 2024-10-01 19:20:28 +00:00
Committed by: Kent Keirsey

6 changed files with 37 additions and 56 deletions


@@ -235,7 +235,8 @@ class FluxGGUFCheckpointModel(ModelLoader):
         with SilenceWarnings():
             # Load the state dict and patcher
-            sd = gguf_sd_loader(model_path)
+            # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
+            sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
             model = Flux(params[config.config_path])
             model.load_state_dict(sd, assign=True)
         return model


@@ -412,7 +412,7 @@ class ModelProbe(object):
                 assert isinstance(model, dict)
                 return model
             elif model_path.suffix.endswith(".gguf"):
-                return gguf_sd_loader(model_path)
+                return gguf_sd_loader(model_path, compute_dtype=torch.float32)
             else:
                 return safetensors.torch.load_file(model_path)


@@ -57,7 +57,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]:
             if scan_result.infected_files != 0:
                 raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
         if str(path).endswith(".gguf"):
-            checkpoint = gguf_sd_loader(Path(path))
+            checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
         else:
            checkpoint = torch.load(path, map_location=torch.device("meta"))
    return checkpoint
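Both probing call sites above choose torch.float32: these paths only inspect tensor keys, shapes, and values on the CPU rather than running the model, so full precision is the safe default. A minimal usage sketch follows; the module path and the file name are assumptions for illustration, not taken from this diff.

import torch
from pathlib import Path

# Module path and GGUF file name are hypothetical.
from invokeai.backend.model_manager.util.model_util import read_checkpoint_meta

# Keys and shapes are enough for probing; values dequantize to float32 on CPU.
meta = read_checkpoint_meta(Path("flux1-dev-Q4_0.gguf"))
print(sorted(meta.keys())[:5])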


@@ -9,43 +9,15 @@ from invokeai.backend.quantization.gguf.utils import (
     dequantize,
 )

-# Ranking of precision preference for different dtypes.
-# When applying an operation involving a GGMLTensor and other non-GGMLTensors, we will run the operation at the
-# highest precision of the non-GGMLTensors.
-DTYPE_PRECISION_RANK = {
-    torch.float64: 0,
-    torch.float32: 1,
-    torch.bfloat16: 2,  # Note: We prefer bfloat16 over float16 for our typical use cases.
-    torch.float16: 3,
-    torch.float8_e4m3fn: 4,
-}
-
-
-def choose_highest_precision_dtype(dtypes: list[torch.dtype]) -> torch.dtype:
-    if len(dtypes) == 0:
-        # TODO(ryand): If we ever hit this case, there's a good chance we'll want to allow the user to specify the
-        # desired compute dtype.
-        return torch.float32
-
-    return min(dtypes, key=lambda dtype: DTYPE_PRECISION_RANK[dtype])


 def dequantize_and_run(func, args, kwargs):
     """A helper function for running math ops on GGMLTensor inputs.

     Dequantizes the inputs, and runs the function.
     """
-    # Determine which precision to run the operation at.
-    all_input_dtypes = [a.dtype for a in args if type(a) is torch.Tensor] + [
-        v.dtype for v in kwargs.values() if type(v) is torch.Tensor
-    ]
-    compute_dtype = choose_highest_precision_dtype(all_input_dtypes)
-
-    dequantized_args = [
-        a.get_dequantized_tensor(dtype=compute_dtype) if hasattr(a, "get_dequantized_tensor") else a for a in args
-    ]
+    dequantized_args = [a.get_dequantized_tensor() if hasattr(a, "get_dequantized_tensor") else a for a in args]
     dequantized_kwargs = {
-        k: v.get_dequantized_tensor(dtype=compute_dtype) if hasattr(v, "get_dequantized_tensor") else v
-        for k, v in kwargs.items()
+        k: v.get_dequantized_tensor() if hasattr(v, "get_dequantized_tensor") else v for k, v in kwargs.items()
     }
     return func(*dequantized_args, **dequantized_kwargs)
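After this change, dequantize_and_run no longer inspects the other operands' dtypes; it leans entirely on duck typing plus each tensor's own compute_dtype. A self-contained toy shows the protocol it relies on (FakeQuantized is invented for illustration and is not the real GGMLTensor):

import torch

class FakeQuantized:
    """Toy stand-in: anything exposing get_dequantized_tensor() is dequantized
    before the op runs; everything else is passed through unchanged."""

    def __init__(self, data: torch.Tensor, compute_dtype: torch.dtype):
        self.data = data
        self.compute_dtype = compute_dtype

    def get_dequantized_tensor(self) -> torch.Tensor:
        # The real implementation dequantizes GGML-quantized data; for
        # torch-compatible qtypes it reduces to a cast like this one.
        return self.data.to(self.compute_dtype)

def dequantize_and_run(func, args, kwargs):
    dequantized_args = [a.get_dequantized_tensor() if hasattr(a, "get_dequantized_tensor") else a for a in args]
    dequantized_kwargs = {
        k: v.get_dequantized_tensor() if hasattr(v, "get_dequantized_tensor") else v for k, v in kwargs.items()
    }
    return func(*dequantized_args, **dequantized_kwargs)

w = FakeQuantized(torch.randn(8, 8), compute_dtype=torch.bfloat16)
x = torch.randn(4, 8, dtype=torch.bfloat16)
out = dequantize_and_run(torch.nn.functional.linear, (x, w), {})
assert out.dtype == torch.bfloat16  # the weight's compute_dtype drives the op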
@@ -67,7 +39,9 @@ def apply_to_quantized_tensor(func, args, kwargs):
         # This is intended to catch calls such as `.to(dtype=torch.float32)`, which are not supported on GGMLTensors.
         raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
-    return GGMLTensor(new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape)
+    return GGMLTensor(
+        new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape, ggml_tensor.compute_dtype
+    )


 GGML_TENSOR_OP_TABLE = {
@@ -89,7 +63,13 @@ class GGMLTensor(torch.Tensor):
     """

     @staticmethod
-    def __new__(cls, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        ggml_quantization_type: gguf.GGMLQuantizationType,
+        tensor_shape: torch.Size,
+        compute_dtype: torch.dtype,
+    ):
         # Type hinting is not supported for torch.Tensor._make_wrapper_subclass, so we ignore the errors.
         return torch.Tensor._make_wrapper_subclass(  # pyright: ignore
             cls,
@@ -101,11 +81,18 @@ class GGMLTensor(torch.Tensor):
             storage_offset=data.storage_offset(),
         )

-    def __init__(self, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
+    def __init__(
+        self,
+        data: torch.Tensor,
+        ggml_quantization_type: gguf.GGMLQuantizationType,
+        tensor_shape: torch.Size,
+        compute_dtype: torch.dtype,
+    ):
         self.quantized_data = data
         self._ggml_quantization_type = ggml_quantization_type
         # The dequantized shape of the tensor.
         self._tensor_shape = tensor_shape
+        self.compute_dtype = compute_dtype

     def __repr__(self, *, tensor_contents=None):
         return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self._tensor_shape})"
@@ -138,23 +125,23 @@ class GGMLTensor(torch.Tensor):
         """
         return self

-    def get_dequantized_tensor(self, dtype: torch.dtype):
+    def get_dequantized_tensor(self):
         """Return the dequantized tensor.
-
-        Args:
-            dtype: The dtype of the dequantized tensor.
         """
         if self._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
-            return self.quantized_data.to(dtype)
+            return self.quantized_data.to(self.compute_dtype)
         elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
             # TODO(ryand): Look into how the dtype param is intended to be used.
             return dequantize(
                 data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self._tensor_shape, dtype=None
-            ).to(dtype)
+            ).to(self.compute_dtype)
         else:
             # There is no GPU implementation for this quantization type, so fallback to the numpy implementation.
             new = gguf.quants.dequantize(self.quantized_data.cpu().numpy(), self._ggml_quantization_type)
-            return torch.from_numpy(new).to(self.quantized_data.device, dtype=dtype)
+            return torch.from_numpy(new).to(self.quantized_data.device, dtype=self.compute_dtype)

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
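Taken together, the compute dtype chosen at construction now drives every dequantization, whichever branch runs. A minimal end-to-end sketch, modeled on the quantize_tensor test helper later in this commit (assumes torch and the gguf package are installed; only APIs shown in this diff are used):

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor

x = torch.randn(32, 64)
quantized_np = gguf.quantize(x.cpu().numpy(), gguf.GGMLQuantizationType.Q8_0)
t = GGMLTensor(
    data=torch.from_numpy(quantized_np),
    ggml_quantization_type=gguf.GGMLQuantizationType.Q8_0,
    tensor_shape=x.shape,
    compute_dtype=torch.bfloat16,
)

# Dequantization lands on the stored compute_dtype, not on a per-call argument.
assert t.get_dequantized_tensor().dtype == torch.bfloat16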


@@ -7,7 +7,7 @@ from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES


-def gguf_sd_loader(path: Path) -> dict[str, GGMLTensor]:
+def gguf_sd_loader(path: Path, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
     reader = gguf.GGUFReader(path)

     sd: dict[str, GGMLTensor] = {}
@@ -16,5 +16,7 @@ def gguf_sd_loader(path: Path) -> dict[str, GGMLTensor]:
         shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
         if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
             torch_tensor = torch_tensor.view(*shape)
-        sd[tensor.name] = GGMLTensor(torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape)
+        sd[tensor.name] = GGMLTensor(
+            torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape, compute_dtype=compute_dtype
+        )
     return sd
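With the signature change, callers of gguf_sd_loader choose the dequantization dtype once, at load time, and every GGMLTensor in the returned state dict carries it. A usage sketch ("model.gguf" is a placeholder path; the module path is inferred from the imports shown above):

import torch
from pathlib import Path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

# The caller now decides the compute dtype up front.
sd = gguf_sd_loader(Path("model.gguf"), compute_dtype=torch.bfloat16)
assert all(t.compute_dtype == torch.bfloat16 for t in sd.values())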


@@ -2,7 +2,7 @@ import gguf
 import pytest
 import torch

-from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor, choose_highest_precision_dtype
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


 def quantize_tensor(data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType) -> GGMLTensor:
@@ -13,7 +13,10 @@ def quantize_tensor(data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType) -> GGMLTensor:
     data_np = data.detach().cpu().numpy()
     quantized_np = gguf.quantize(data_np, ggml_quantization_type)
     return GGMLTensor(
-        data=torch.from_numpy(quantized_np), ggml_quantization_type=ggml_quantization_type, tensor_shape=data.shape
+        data=torch.from_numpy(quantized_np),
+        ggml_quantization_type=ggml_quantization_type,
+        tensor_shape=data.shape,
+        compute_dtype=data.dtype,
     ).to(device=data.device)  # type: ignore
@@ -95,18 +98,6 @@ def test_ggml_tensor_to_device():
     assert torch.allclose(x_cpu.quantized_data, x_gpu.quantized_data.cpu(), atol=1e-5)


-@pytest.mark.parametrize(
-    ["dtypes", "expected_dtype"],
-    [
-        ([], torch.float32),  # Default to float32 if no dtypes are provided.
-        ([torch.float32, torch.float16, torch.bfloat16], torch.float32),
-        ([torch.float16, torch.bfloat16], torch.bfloat16),
-    ],
-)
-def test_choose_highest_precision_dtype(dtypes: list[torch.dtype], expected_dtype: torch.dtype):
-    assert choose_highest_precision_dtype(dtypes) == expected_dtype


 def test_ggml_tensor_shape():
     x = torch.randn(32, 64)
     x_quantized = quantize_tensor(x, gguf.GGMLQuantizationType.Q8_0)