diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index 631709b6e3..208d0f396b 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -123,6 +123,12 @@ class GGMLTensor(torch.Tensor): def to(self, *args, **kwargs) -> torch.Tensor: ... def to(self, *args, **kwargs): + for func_arg in args: + if isinstance(func_arg, torch.dtype) and func_arg != self.quantized_data.dtype: + raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") + if 'dtype' in kwargs.keys(): + if kwargs['dtype'] != self.quantized_data.dtype: + raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly." self.quantized_data = self.quantized_data.to(*args, **kwargs) return self