Run ruff and update imports

This commit is contained in:
Brandon Rising
2024-09-30 22:27:27 -04:00
committed by Kent Keirsey
parent 66bbd62758
commit 7d9f125232
4 changed files with 5 additions and 7 deletions

View File

@@ -31,7 +31,7 @@ from invokeai.backend.model_manager.config import (
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.loaders import load_gguf_sd
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -412,7 +412,7 @@ class ModelProbe(object):
assert isinstance(model, dict)
return model
elif model_path.suffix.endswith(".gguf"):
return load_gguf_sd(model_path)
return gguf_sd_loader(model_path)
else:
return safetensors.torch.load_file(model_path)

View File

@@ -8,7 +8,7 @@ import safetensors
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.quantization.gguf.loaders import load_gguf_sd
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
@@ -57,7 +57,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
if scan_result.infected_files != 0:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
if str(path).endswith(".gguf"):
checkpoint = load_gguf_sd(Path(path))
checkpoint = gguf_sd_loader(Path(path))
else:
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint

View File

@@ -9,8 +9,6 @@ from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
TORCH_COMPATIBLE_QTYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}
def gguf_sd_loader(path: Path) -> dict[str, GGUFTensor]:
reader = gguf.GGUFReader(path)

View File

@@ -14,4 +14,4 @@ def test_ggml_tensor():
ggml_tensor = GGMLTensor(data, tensor_type, tensor_shape)
ones = torch.ones([1], dtype=torch.float32)
x = ggml_tensor * ones
_ = ggml_tensor * ones