mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-18 01:37:56 -05:00
Compare commits
5 Commits
controlnet
...
psyche/tes
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34457fc381 | ||
|
|
88ea6b538d | ||
|
|
9707724ec2 | ||
|
|
ae65073711 | ||
|
|
558dcf8cea |
@@ -572,6 +572,8 @@ class CheckpointProbeBase(ProbeBase):
|
||||
|
||||
if in_channels is None:
|
||||
# If we cannot find the in_channels, we assume that this is a normal variant. Log a warning.
|
||||
# If this occurs, we should add a test case for the affected model here:
|
||||
# tests/backend/flux/test_flux_state_dict_utils.py
|
||||
logger.warning(
|
||||
f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant."
|
||||
)
|
||||
|
||||
35
tests/backend/flux/test_flux_state_dict_utils.py
Normal file
35
tests/backend/flux/test_flux_state_dict_utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
|
||||
from invokeai.backend.model_manager.config import ModelOnDisk
|
||||
|
||||
test_cases = [
|
||||
# Unquantized
|
||||
("FLUX Dev.safetensors", 64),
|
||||
("FLUX Schnell.safetensors", 64),
|
||||
("FLUX Fill.safetensors", 384),
|
||||
# BNB-NF4 quantized
|
||||
("FLUX Dev (Quantized).safetensors", 1), # BNB-NF4
|
||||
("FLUX Schnell (Quantized).safetensors", 1), # BNB-NF4
|
||||
# GGUF quantized FLUX Fill
|
||||
("flux1-fill-dev-Q8_0.gguf", 384),
|
||||
# Fine-tune w/ "model.diffusion_model.img_in.weight" instead of "img_in.weight"
|
||||
("midjourneyReplica_flux1Dev.safetensors", 64),
|
||||
# Not a FLUX model, testing fallback case
|
||||
("Noodles Style.safetensors", None),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_file_name,expected_in_channels", test_cases)
|
||||
def test_get_flux_in_channels_from_state_dict(model_file_name: str, expected_in_channels: int, override_model_loading):
|
||||
model_path = Path(f"tests/test_model_probe/stripped_models/{model_file_name}")
|
||||
|
||||
mod = ModelOnDisk(model_path)
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
|
||||
in_channels = get_flux_in_channels_from_state_dict(state_dict)
|
||||
|
||||
assert in_channels == expected_in_channels
|
||||
@@ -96,6 +96,9 @@ def override_model_loading(monkeypatch):
|
||||
monkeypatch.setattr(safetensors.torch, "load", load_stripped_model)
|
||||
monkeypatch.setattr(safetensors.torch, "load_file", load_stripped_model)
|
||||
monkeypatch.setattr(gguf_loaders, "gguf_sd_loader", load_stripped_model)
|
||||
monkeypatch.setattr("invokeai.backend.model_manager.config.gguf_sd_loader", load_stripped_model)
|
||||
monkeypatch.setattr("invokeai.backend.model_manager.util.model_util.gguf_sd_loader", load_stripped_model)
|
||||
monkeypatch.setattr("invokeai.backend.model_manager.legacy_probe.gguf_sd_loader", load_stripped_model)
|
||||
|
||||
def fake_scan(*args, **kwargs):
|
||||
return SimpleNamespace(infected_files=0, scan_err=None)
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fe25212279fec351340d1c4a9da0eb902af82162350970c148bf331c1c02f3c5
|
||||
size 292730
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:84850676ab6fc163b4fe3bb87b1584a5a78b523e5f6e58b6ecb2c7d34e4c0796
|
||||
size 130743
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:eb64744f32674cd1e8c3c09e578d18e1ca84c3deac0ef0a2fc3654ec9ac0a84d
|
||||
size 130744
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:42cd75dbd5dec6252de6f959a6ed678fb0e5bef166eca7ac38c51577a0d4e4eb
|
||||
size 291091
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a1533dced878ca5a8bae39bfdbed85dfd97e937ec3c97540da1e7d4011ffed98
|
||||
size 130098
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cac069dd904e0d676baacecfeaba52bbbe808a6d755dabdd94c7281656fa0507
|
||||
size 129356
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:98d0f54489ec096f543a9b8f88683fd960acd96521d987e027be9e23d621d96f
|
||||
size 151803
|
||||
Reference in New Issue
Block a user