Compare commits

...

5 Commits

Author SHA1 Message Date
psychedelicious
34457fc381 tests: monkeypatch more references to gguf_sd_loader()
The new util in de9f541bf60199e14d159dc3511482ae38a5cb60 alters import order, breaking some model probe tests. We need to patch more references to `gguf_sd_loader()` to fix em
2025-05-20 11:29:28 +10:00
psychedelicious
88ea6b538d tests: add test for get_flux_in_channels_from_state_dict() 2025-05-20 11:29:28 +10:00
psychedelicious
9707724ec2 docs: add reminder to FLUX variant probing to add test cases if we have a probe failure 2025-05-20 11:29:28 +10:00
psychedelicious
ae65073711 tests: add stripped models for FLUX varieties
Stripped models for:
- FLUX Dev.safetensors
- FLUX Schnell.safetensors
- FLUX Fill.safetensors
- FLUX Dev (Quantized).safetensors
- FLUX Schnell (Quantized).safetensors
- flux1-fill-dev-Q8_0.gguf
- midjourneyReplica_flux1Dev.safetensors
2025-05-20 11:29:28 +10:00
psychedelicious
558dcf8cea tests: monkeypatch secondary reference to gguf_sd_loader()
`gguf_sd_loader()` has multiple references in the codebase. It is imported before monkeypatching, so we need to monkeypatch another reference to it. This fixes tests for `ModelOnDisk.load_state_dict()`.
2025-05-20 11:29:28 +10:00
10 changed files with 61 additions and 0 deletions

View File

@@ -572,6 +572,8 @@ class CheckpointProbeBase(ProbeBase):
if in_channels is None:
# If we cannot find the in_channels, we assume that this is a normal variant. Log a warning.
# If this occurs, we should add a test case for the affected model here:
# tests/backend/flux/test_flux_state_dict_utils.py
logger.warning(
f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant."
)

View File

@@ -0,0 +1,35 @@
from pathlib import Path
import pytest
from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
from invokeai.backend.model_manager.config import ModelOnDisk
test_cases = [
# Unquantized
("FLUX Dev.safetensors", 64),
("FLUX Schnell.safetensors", 64),
("FLUX Fill.safetensors", 384),
# BNB-NF4 quantized
("FLUX Dev (Quantized).safetensors", 1), # BNB-NF4
("FLUX Schnell (Quantized).safetensors", 1), # BNB-NF4
# GGUF quantized FLUX Fill
("flux1-fill-dev-Q8_0.gguf", 384),
# Fine-tune w/ "model.diffusion_model.img_in.weight" instead of "img_in.weight"
("midjourneyReplica_flux1Dev.safetensors", 64),
# Not a FLUX model, testing fallback case
("Noodles Style.safetensors", None),
]
@pytest.mark.parametrize("model_file_name,expected_in_channels", test_cases)
def test_get_flux_in_channels_from_state_dict(model_file_name: str, expected_in_channels: int, override_model_loading):
model_path = Path(f"tests/test_model_probe/stripped_models/{model_file_name}")
mod = ModelOnDisk(model_path)
state_dict = mod.load_state_dict()
in_channels = get_flux_in_channels_from_state_dict(state_dict)
assert in_channels == expected_in_channels

View File

@@ -96,6 +96,9 @@ def override_model_loading(monkeypatch):
monkeypatch.setattr(safetensors.torch, "load", load_stripped_model)
monkeypatch.setattr(safetensors.torch, "load_file", load_stripped_model)
monkeypatch.setattr(gguf_loaders, "gguf_sd_loader", load_stripped_model)
monkeypatch.setattr("invokeai.backend.model_manager.config.gguf_sd_loader", load_stripped_model)
monkeypatch.setattr("invokeai.backend.model_manager.util.model_util.gguf_sd_loader", load_stripped_model)
monkeypatch.setattr("invokeai.backend.model_manager.legacy_probe.gguf_sd_loader", load_stripped_model)
def fake_scan(*args, **kwargs):
return SimpleNamespace(infected_files=0, scan_err=None)

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fe25212279fec351340d1c4a9da0eb902af82162350970c148bf331c1c02f3c5
size 292730

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:84850676ab6fc163b4fe3bb87b1584a5a78b523e5f6e58b6ecb2c7d34e4c0796
size 130743

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb64744f32674cd1e8c3c09e578d18e1ca84c3deac0ef0a2fc3654ec9ac0a84d
size 130744

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42cd75dbd5dec6252de6f959a6ed678fb0e5bef166eca7ac38c51577a0d4e4eb
size 291091

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a1533dced878ca5a8bae39bfdbed85dfd97e937ec3c97540da1e7d4011ffed98
size 130098

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cac069dd904e0d676baacecfeaba52bbbe808a6d755dabdd94c7281656fa0507
size 129356

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:98d0f54489ec096f543a9b8f88683fd960acd96521d987e027be9e23d621d96f
size 151803