mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
test(recall): cover loras, control layers, and ip_adapters paths
The original recall_parameters router (PR #8758) shipped without any unit tests for its three collection fields. This commit backfills that coverage alongside the reference_images tests added in the previous commit. The resolver helpers (resolve_model_name_to_key, load_image_file, process_controlnet_image) are monkey-patched via module-level attribute replacement so each test can pin down a specific resolution outcome without spinning up the model manager or an image-files service. Two small factory helpers (make_name_to_key_stub / make_load_image_file_stub) make that ergonomic. New coverage: * LoRAs — multi-entry resolution + weight/is_enabled pass-through, silent drop on unresolvable names, is_enabled default of True. * Control layers — ControlNet resolution precedence, fall-through to T2I Adapter and Control LoRA in order, missing image gracefully warned-and-continued, processed_image attached when the processor returns data, unresolvable entries dropped. * IP Adapters — IPAdapter-before-FluxRedux lookup order, method / image_influence pass-through, missing image gracefully warned-and- continued, unresolvable entries dropped. * Combined happy path — full request with prompts + model + all four collection fields, verifying every resolved value reaches the broadcast payload. * Main-model drop — an unresolvable main model is scrubbed from the broadcast so the frontend never receives a stale model name. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
"""Tests for the recall parameters router.
|
||||
|
||||
Focused on the ``reference_images`` field added for model-free reference
|
||||
images (FLUX.2 Klein, FLUX Kontext, Qwen Image Edit). The existing
|
||||
``loras`` / ``control_layers`` / ``ip_adapters`` paths are exercised via
|
||||
integration tests elsewhere; this file pins down the new field's
|
||||
request-validation, resolver behavior, and event payload.
|
||||
These tests monkey-patch the heavy-weight lookup helpers
|
||||
(``resolve_model_name_to_key``, ``load_image_file``,
|
||||
``process_controlnet_image``) rather than wiring up a real model manager
|
||||
or image-files service. This keeps each test focused on the router's
|
||||
request-validation, resolver sequencing, and broadcast payload shape.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
@@ -16,6 +17,7 @@ from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.routers import recall_parameters as recall_module
|
||||
from invokeai.app.api_app import app
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.taxonomy import ModelType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -53,6 +55,36 @@ def patched_dependencies(monkeypatch: Any, mock_invoker: Invoker) -> MockApiDepe
|
||||
return dependencies
|
||||
|
||||
|
||||
def make_name_to_key_stub(
|
||||
mapping: dict[tuple[str, ModelType], str],
|
||||
) -> Callable[[str, ModelType], Optional[str]]:
|
||||
"""Build a ``resolve_model_name_to_key`` stand-in from a (name, type) dict.
|
||||
|
||||
Any lookup that is not present in ``mapping`` returns ``None``, mirroring
|
||||
what the real resolver does when the model manager cannot find a match.
|
||||
"""
|
||||
|
||||
def _lookup(model_name: str, model_type: ModelType = ModelType.Main) -> Optional[str]:
|
||||
return mapping.get((model_name, model_type))
|
||||
|
||||
return _lookup
|
||||
|
||||
|
||||
def make_load_image_file_stub(
|
||||
known_images: dict[str, tuple[int, int]],
|
||||
) -> Callable[[str], Optional[dict[str, Any]]]:
|
||||
"""Build a ``load_image_file`` stand-in from a name → (width, height) dict."""
|
||||
|
||||
def _load(image_name: str) -> Optional[dict[str, Any]]:
|
||||
dims = known_images.get(image_name)
|
||||
if dims is None:
|
||||
return None
|
||||
width, height = dims
|
||||
return {"image_name": image_name, "width": width, "height": height}
|
||||
|
||||
return _load
|
||||
|
||||
|
||||
class TestReferenceImagesRecall:
|
||||
def test_reference_images_forwarded_when_image_exists(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
@@ -157,3 +189,443 @@ class TestReferenceImagesRecall:
|
||||
params = response.json()["parameters"]
|
||||
assert params["positive_prompt"] == "hello"
|
||||
assert params["reference_images"] == []
|
||||
|
||||
|
||||
class TestLorasRecall:
|
||||
def test_multiple_loras_resolved_with_weights_and_is_enabled(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""Each LoRA's model name is resolved to a key and weight/is_enabled pass through."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub(
|
||||
{
|
||||
("detail-lora", ModelType.LoRA): "key-detail",
|
||||
("style-lora", ModelType.LoRA): "key-style",
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={
|
||||
"loras": [
|
||||
{"model_name": "detail-lora", "weight": 0.8, "is_enabled": True},
|
||||
{"model_name": "style-lora", "weight": 0.5, "is_enabled": False},
|
||||
]
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
loras = response.json()["parameters"]["loras"]
|
||||
assert loras == [
|
||||
{"model_key": "key-detail", "weight": 0.8, "is_enabled": True},
|
||||
{"model_key": "key-style", "weight": 0.5, "is_enabled": False},
|
||||
]
|
||||
|
||||
def test_unresolvable_loras_are_dropped(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""LoRAs whose names do not resolve are silently skipped — not an error."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({("keeper", ModelType.LoRA): "key-keeper"}),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={
|
||||
"loras": [
|
||||
{"model_name": "keeper", "weight": 0.7},
|
||||
{"model_name": "ghost-lora"},
|
||||
]
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
loras = response.json()["parameters"]["loras"]
|
||||
assert len(loras) == 1
|
||||
assert loras[0]["model_key"] == "key-keeper"
|
||||
|
||||
def test_is_enabled_defaults_to_true(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""Omitting is_enabled should default to True per the pydantic schema."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({("x", ModelType.LoRA): "key-x"}),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={"loras": [{"model_name": "x"}]},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["parameters"]["loras"][0]["is_enabled"] is True
|
||||
|
||||
|
||||
class TestControlLayersRecall:
|
||||
def test_controlnet_resolution_takes_precedence(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""A name that matches a ControlNet model should resolve to it directly."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({("canny", ModelType.ControlNet): "key-canny"}),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"load_image_file",
|
||||
make_load_image_file_stub({"ctl.png": (512, 512)}),
|
||||
)
|
||||
monkeypatch.setattr(recall_module, "process_controlnet_image", lambda *a, **kw: None)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={
|
||||
"control_layers": [
|
||||
{
|
||||
"model_name": "canny",
|
||||
"image_name": "ctl.png",
|
||||
"weight": 0.75,
|
||||
"begin_step_percent": 0.1,
|
||||
"end_step_percent": 0.9,
|
||||
"control_mode": "balanced",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
layer = response.json()["parameters"]["control_layers"][0]
|
||||
assert layer["model_key"] == "key-canny"
|
||||
assert layer["weight"] == 0.75
|
||||
assert layer["begin_step_percent"] == 0.1
|
||||
assert layer["end_step_percent"] == 0.9
|
||||
assert layer["control_mode"] == "balanced"
|
||||
assert layer["image"] == {"image_name": "ctl.png", "width": 512, "height": 512}
|
||||
# processor returned None → no processed_image field
|
||||
assert "processed_image" not in layer
|
||||
|
||||
def test_falls_back_to_t2i_adapter(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""When no ControlNet match exists, T2I Adapter is tried next."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({("sketchy", ModelType.T2IAdapter): "key-t2i"}),
|
||||
)
|
||||
monkeypatch.setattr(recall_module, "load_image_file", make_load_image_file_stub({}))
|
||||
monkeypatch.setattr(recall_module, "process_controlnet_image", lambda *a, **kw: None)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={"control_layers": [{"model_name": "sketchy", "weight": 1.0}]},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["parameters"]["control_layers"][0]["model_key"] == "key-t2i"
|
||||
|
||||
def test_falls_back_to_control_lora(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""When neither ControlNet nor T2I Adapter matches, Control LoRA is tried last."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({("clora", ModelType.LoRA): "key-clora"}),
|
||||
)
|
||||
monkeypatch.setattr(recall_module, "load_image_file", make_load_image_file_stub({}))
|
||||
monkeypatch.setattr(recall_module, "process_controlnet_image", lambda *a, **kw: None)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={"control_layers": [{"model_name": "clora", "weight": 1.0}]},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["parameters"]["control_layers"][0]["model_key"] == "key-clora"
|
||||
|
||||
def test_missing_image_still_resolves_config(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""A missing control image is warned about but does not block the rest of the config."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({("canny", ModelType.ControlNet): "key-canny"}),
|
||||
)
|
||||
monkeypatch.setattr(recall_module, "load_image_file", make_load_image_file_stub({}))
|
||||
monkeypatch.setattr(recall_module, "process_controlnet_image", lambda *a, **kw: None)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={
|
||||
"control_layers": [
|
||||
{
|
||||
"model_name": "canny",
|
||||
"image_name": "missing.png",
|
||||
"weight": 0.75,
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
layer = response.json()["parameters"]["control_layers"][0]
|
||||
assert layer["model_key"] == "key-canny"
|
||||
assert layer["weight"] == 0.75
|
||||
assert "image" not in layer
|
||||
assert "processed_image" not in layer
|
||||
|
||||
def test_processed_image_included_when_processor_returns_data(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""When the processor produces a derived image, it is attached to the resolved layer."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({("canny", ModelType.ControlNet): "key-canny"}),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"load_image_file",
|
||||
make_load_image_file_stub({"ctl.png": (768, 768)}),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"process_controlnet_image",
|
||||
lambda image_name, model_key, services: {
|
||||
"image_name": f"processed-{image_name}",
|
||||
"width": 768,
|
||||
"height": 768,
|
||||
},
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={"control_layers": [{"model_name": "canny", "image_name": "ctl.png", "weight": 1.0}]},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
layer = response.json()["parameters"]["control_layers"][0]
|
||||
assert layer["processed_image"]["image_name"] == "processed-ctl.png"
|
||||
|
||||
def test_unresolvable_control_layers_are_dropped(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""Control entries whose model doesn't resolve by any type are skipped."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({}),
|
||||
)
|
||||
monkeypatch.setattr(recall_module, "load_image_file", make_load_image_file_stub({}))
|
||||
monkeypatch.setattr(recall_module, "process_controlnet_image", lambda *a, **kw: None)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={"control_layers": [{"model_name": "unknown", "weight": 1.0}]},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["parameters"]["control_layers"] == []
|
||||
|
||||
|
||||
class TestIPAdaptersRecall:
|
||||
def test_ip_adapter_resolved_with_image_and_method(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""IPAdapter lookup is tried first and all config fields pass through."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({("ipa-face", ModelType.IPAdapter): "key-ipa"}),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"load_image_file",
|
||||
make_load_image_file_stub({"ref.png": (1024, 1024)}),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={
|
||||
"ip_adapters": [
|
||||
{
|
||||
"model_name": "ipa-face",
|
||||
"image_name": "ref.png",
|
||||
"weight": 0.7,
|
||||
"begin_step_percent": 0.0,
|
||||
"end_step_percent": 0.8,
|
||||
"method": "style",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
adapter = response.json()["parameters"]["ip_adapters"][0]
|
||||
assert adapter["model_key"] == "key-ipa"
|
||||
assert adapter["weight"] == 0.7
|
||||
assert adapter["begin_step_percent"] == 0.0
|
||||
assert adapter["end_step_percent"] == 0.8
|
||||
assert adapter["method"] == "style"
|
||||
assert adapter["image"] == {"image_name": "ref.png", "width": 1024, "height": 1024}
|
||||
# image_influence was not sent, so it must not appear in the resolved config
|
||||
assert "image_influence" not in adapter
|
||||
|
||||
def test_falls_back_to_flux_redux(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""When the name doesn't match an IPAdapter, FluxRedux is tried next."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({("redux-1", ModelType.FluxRedux): "key-redux"}),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"load_image_file",
|
||||
make_load_image_file_stub({"ref.png": (512, 512)}),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={
|
||||
"ip_adapters": [
|
||||
{
|
||||
"model_name": "redux-1",
|
||||
"image_name": "ref.png",
|
||||
"weight": 1.0,
|
||||
"image_influence": "high",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
adapter = response.json()["parameters"]["ip_adapters"][0]
|
||||
assert adapter["model_key"] == "key-redux"
|
||||
assert adapter["image_influence"] == "high"
|
||||
|
||||
def test_missing_image_still_resolves_config(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""A missing reference image is warned about but the adapter still lands."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({("ipa", ModelType.IPAdapter): "key-ipa"}),
|
||||
)
|
||||
monkeypatch.setattr(recall_module, "load_image_file", make_load_image_file_stub({}))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={"ip_adapters": [{"model_name": "ipa", "image_name": "missing.png", "weight": 0.5}]},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
adapter = response.json()["parameters"]["ip_adapters"][0]
|
||||
assert adapter["model_key"] == "key-ipa"
|
||||
assert adapter["weight"] == 0.5
|
||||
assert "image" not in adapter
|
||||
|
||||
def test_unresolvable_ip_adapters_are_dropped(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""Adapters whose model can't be resolved (neither IPAdapter nor FluxRedux) are skipped."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({}),
|
||||
)
|
||||
monkeypatch.setattr(recall_module, "load_image_file", make_load_image_file_stub({}))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={"ip_adapters": [{"model_name": "unknown", "weight": 1.0}]},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["parameters"]["ip_adapters"] == []
|
||||
|
||||
|
||||
class TestCombinedRecall:
|
||||
def test_all_collection_fields_together(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""Exercise the full happy path: prompts, model, loras, control_layers, ip_adapters, reference_images."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub(
|
||||
{
|
||||
("my-model", ModelType.Main): "key-main",
|
||||
("detail-lora", ModelType.LoRA): "key-lora",
|
||||
("canny", ModelType.ControlNet): "key-canny",
|
||||
("ipa-face", ModelType.IPAdapter): "key-ipa",
|
||||
}
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"load_image_file",
|
||||
make_load_image_file_stub(
|
||||
{
|
||||
"ctl.png": (512, 512),
|
||||
"face.png": (768, 768),
|
||||
"ref.png": (1024, 1024),
|
||||
}
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(recall_module, "process_controlnet_image", lambda *a, **kw: None)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={
|
||||
"positive_prompt": "a cat",
|
||||
"negative_prompt": "blurry",
|
||||
"model": "my-model",
|
||||
"steps": 30,
|
||||
"cfg_scale": 7.5,
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"seed": 42,
|
||||
"loras": [{"model_name": "detail-lora", "weight": 0.6}],
|
||||
"control_layers": [{"model_name": "canny", "image_name": "ctl.png", "weight": 0.75}],
|
||||
"ip_adapters": [
|
||||
{"model_name": "ipa-face", "image_name": "face.png", "weight": 0.5, "method": "composition"}
|
||||
],
|
||||
"reference_images": [{"image_name": "ref.png"}],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
params = response.json()["parameters"]
|
||||
|
||||
# Core fields
|
||||
assert params["positive_prompt"] == "a cat"
|
||||
assert params["negative_prompt"] == "blurry"
|
||||
assert params["model"] == "key-main"
|
||||
assert params["steps"] == 30
|
||||
assert params["seed"] == 42
|
||||
|
||||
# Collections
|
||||
assert params["loras"] == [{"model_key": "key-lora", "weight": 0.6, "is_enabled": True}]
|
||||
assert params["control_layers"][0]["model_key"] == "key-canny"
|
||||
assert params["control_layers"][0]["image"]["image_name"] == "ctl.png"
|
||||
assert params["ip_adapters"][0]["model_key"] == "key-ipa"
|
||||
assert params["ip_adapters"][0]["method"] == "composition"
|
||||
assert params["reference_images"] == [{"image": {"image_name": "ref.png", "width": 1024, "height": 1024}}]
|
||||
|
||||
def test_unresolvable_main_model_drops_from_payload(
|
||||
self, monkeypatch: Any, patched_dependencies: MockApiDependencies, client: TestClient
|
||||
) -> None:
|
||||
"""A model name that doesn't resolve should be scrubbed from the broadcast payload."""
|
||||
monkeypatch.setattr(
|
||||
recall_module,
|
||||
"resolve_model_name_to_key",
|
||||
make_name_to_key_stub({}),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/recall/default",
|
||||
json={"positive_prompt": "x", "model": "ghost-model"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
params = response.json()["parameters"]
|
||||
assert params["positive_prompt"] == "x"
|
||||
assert "model" not in params
|
||||
|
||||
Reference in New Issue
Block a user