mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Broaden text encoder partial-load recovery (#9034)
This commit is contained in:
139
tests/app/invocations/test_compel.py
Normal file
139
tests/app/invocations/test_compel.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.compel import SDXLPromptInvocationBase
|
||||
|
||||
|
||||
class FakeClipTextEncoder(torch.nn.Module):
    """Minimal CLIP text-encoder double simulating a partially loaded model.

    The ``device`` property always claims CPU, while ``active_buffer`` is
    allocated on ``effective_device`` — i.e. the reported device disagrees
    with where (some of) the weights actually live.
    """

    def __init__(self, effective_device: torch.device):
        super().__init__()
        # One parameter pinned to CPU, one buffer on the "active" device.
        self.register_parameter("cpu_param", torch.nn.Parameter(torch.ones(1)))
        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
        self.dtype = torch.float32

    @property
    def device(self) -> torch.device:
        # Deliberately misleading: always reports CPU regardless of buffers.
        return torch.device("cpu")
class FakeTokenizer:
    """Inert tokenizer placeholder; only its type/identity is inspected."""
class FakeLoadedModel:
    """Test double for the model manager's loaded-model handle.

    Supports both access patterns used by the code under test: plain
    ``with`` (yields the model) and ``model_on_device()`` (yields a
    ``(state_dict, model)`` pair; the state dict is unused here).
    """

    def __init__(self, model, config=None):
        self._model = model
        self.config = config

    @contextmanager
    def model_on_device(self):
        yield (None, self._model)

    def __enter__(self):
        return self._model

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions.
        return False
class FakeCompel:
|
||||
last_init_device: torch.device | None = None
|
||||
|
||||
def __init__(self, *args, device: torch.device, **kwargs):
|
||||
del args, kwargs
|
||||
FakeCompel.last_init_device = device
|
||||
self.conditioning_provider = SimpleNamespace(
|
||||
get_pooled_embeddings=lambda prompts: torch.ones((len(prompts), 4), dtype=torch.float32)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_prompt_string(prompt: str) -> str:
|
||||
return prompt
|
||||
|
||||
def build_conditioning_tensor_for_conjunction(self, conjunction: str):
|
||||
del conjunction
|
||||
return torch.ones((1, 4, 4), dtype=torch.float32), {}
|
||||
|
||||
|
||||
@contextmanager
def fake_apply_ti(tokenizer, text_encoder, ti_list):
    """Replacement for ModelPatcher.apply_ti: yields the tokenizer unchanged
    plus an opaque TI-manager placeholder."""
    del text_encoder, ti_list
    yield tokenizer, object()
def test_sdxl_run_clip_compel_uses_effective_device_for_partially_loaded_model(monkeypatch):
    """run_clip_compel should hand Compel the encoder's *effective* device.

    The fake encoder's ``device`` property reports CPU while one of its
    buffers lives on ``meta``, simulating partial load; Compel must be
    constructed with the effective (buffer) device, not the reported one.
    """
    module_path = "invokeai.app.invocations.compel"
    effective_device = torch.device("meta")
    text_encoder = FakeClipTextEncoder(effective_device=effective_device)
    tokenizer = FakeTokenizer()
    text_encoder_info = FakeLoadedModel(text_encoder, config=SimpleNamespace(base="sdxl"))
    tokenizer_info = FakeLoadedModel(tokenizer)

    mock_context = MagicMock()
    # run_clip_compel loads the text encoder first, then the tokenizer.
    mock_context.models.load.side_effect = [text_encoder_info, tokenizer_info]
    mock_context.config.get.return_value.log_tokenization = False
    mock_context.util.signal_progress = MagicMock()

    monkeypatch.setattr(f"{module_path}.CLIPTextModel", FakeClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTextModelWithProjection", FakeClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTokenizer", FakeTokenizer)
    monkeypatch.setattr(f"{module_path}.Compel", FakeCompel)
    monkeypatch.setattr(f"{module_path}.generate_ti_list", lambda prompt, base, context: [])
    monkeypatch.setattr(f"{module_path}.LayerPatcher.apply_smart_model_patches", lambda **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_clip_skip", lambda *args, **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_ti", fake_apply_ti)

    invocation = SDXLPromptInvocationBase()
    conditioning, pooled = invocation.run_clip_compel(
        context=mock_context,
        clip_field=SimpleNamespace(
            text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[], skipped_layers=0
        ),
        prompt="test prompt",
        get_pooled=False,
        lora_prefix="lora_te1_",
        zero_on_empty=False,
    )

    assert FakeCompel.last_init_device == effective_device
    assert conditioning.shape == (1, 4, 4)
    assert pooled is None
def test_sdxl_run_clip_compel_uses_cpu_for_fully_cpu_model(monkeypatch):
    """When every tensor of the encoder is on CPU, Compel gets the CPU device."""
    module_path = "invokeai.app.invocations.compel"
    text_encoder = FakeClipTextEncoder(effective_device=torch.device("cpu"))
    tokenizer = FakeTokenizer()
    text_encoder_info = FakeLoadedModel(text_encoder, config=SimpleNamespace(base="sdxl"))
    tokenizer_info = FakeLoadedModel(tokenizer)

    mock_context = MagicMock()
    # run_clip_compel loads the text encoder first, then the tokenizer.
    mock_context.models.load.side_effect = [text_encoder_info, tokenizer_info]
    mock_context.config.get.return_value.log_tokenization = False
    mock_context.util.signal_progress = MagicMock()

    monkeypatch.setattr(f"{module_path}.CLIPTextModel", FakeClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTextModelWithProjection", FakeClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTokenizer", FakeTokenizer)
    monkeypatch.setattr(f"{module_path}.Compel", FakeCompel)
    monkeypatch.setattr(f"{module_path}.generate_ti_list", lambda prompt, base, context: [])
    monkeypatch.setattr(f"{module_path}.LayerPatcher.apply_smart_model_patches", lambda **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_clip_skip", lambda *args, **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_ti", fake_apply_ti)

    invocation = SDXLPromptInvocationBase()
    invocation.run_clip_compel(
        context=mock_context,
        clip_field=SimpleNamespace(
            text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[], skipped_layers=0
        ),
        prompt="test prompt",
        get_pooled=False,
        lora_prefix="lora_te1_",
        zero_on_empty=False,
    )

    assert FakeCompel.last_init_device == torch.device("cpu")
154
tests/app/invocations/test_sd3_text_encoder.py
Normal file
154
tests/app/invocations/test_sd3_text_encoder.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.sd3_text_encoder import Sd3TextEncoderInvocation
|
||||
from invokeai.backend.model_manager.taxonomy import ModelFormat
|
||||
|
||||
|
||||
class FakeSd3ClipTextEncoder(torch.nn.Module):
    """CLIP text-encoder double whose reported device disagrees with its buffers.

    ``device`` always reports CPU while ``active_buffer`` lives on
    ``effective_device``; ``forward`` records the device of the incoming
    ``input_ids`` so tests can assert where tokens were moved.
    """

    def __init__(self, effective_device: torch.device):
        super().__init__()
        self.register_parameter("cpu_param", torch.nn.Parameter(torch.ones(1)))
        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
        self.dtype = torch.float32
        # Populated by forward(); read by the tests.
        self.forward_input_device: torch.device | None = None

    @property
    def device(self) -> torch.device:
        # Deliberately misleading: always reports CPU.
        return torch.device("cpu")

    def forward(self, input_ids: torch.Tensor, output_hidden_states: bool = False):
        assert output_hidden_states
        self.forward_input_device = input_ids.device
        hidden = input_ids.unsqueeze(-1).float()
        # BUG FIX: previously returned SimpleNamespace(..., __getitem__=lambda ...).
        # Dunder methods are looked up on the *type*, not the instance, so
        # subscripting that object raised TypeError. Return a FakeClipOutput
        # (which defines __getitem__ at class level), matching the
        # monkeypatched forward used in test_sd3_clip_encode_uses_effective_device.
        return FakeClipOutput(hidden_states=[hidden, hidden + 1])


class FakeClipOutput(SimpleNamespace):
    """Subscriptable CLIP output: ``out[idx]`` yields the last hidden state."""

    def __getitem__(self, idx):
        del idx
        return self.hidden_states[-1]
class FakeClipTokenizer:
    """CLIP tokenizer double returning a fixed 3-token id tensor."""

    def __call__(self, prompt, padding, max_length=None, truncation=None, return_tensors=None):
        del prompt, padding, max_length, truncation, return_tensors
        return SimpleNamespace(input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long))

    def batch_decode(self, input_ids):
        del input_ids
        return ["decoded"]
class FakeSd3T5Encoder(torch.nn.Module):
    """T5 encoder double: reports CPU but holds a buffer on ``effective_device``.

    ``forward`` records the device of the incoming ids and returns a 1-tuple
    of hidden states, mirroring the shape of the real encoder's output.
    """

    def __init__(self, effective_device: torch.device):
        super().__init__()
        self.register_parameter("cpu_param", torch.nn.Parameter(torch.ones(1)))
        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
        # Populated by forward(); read by the tests.
        self.forward_input_device: torch.device | None = None

    @property
    def device(self) -> torch.device:
        # Deliberately misleading: always reports CPU.
        return torch.device("cpu")

    def forward(self, input_ids: torch.Tensor):
        self.forward_input_device = input_ids.device
        hidden = input_ids.unsqueeze(-1).float()
        return (hidden,)
class FakeT5Tokenizer:
    """T5 tokenizer double returning a fixed 3-token id tensor."""

    def __call__(self, prompt, padding, max_length=None, truncation=None, add_special_tokens=None, return_tensors=None):
        del prompt, padding, max_length, truncation, add_special_tokens, return_tensors
        return SimpleNamespace(input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long))

    def batch_decode(self, input_ids):
        del input_ids
        return ["decoded"]
class FakeLoadedModel:
    """Test double for the model manager's loaded-model handle.

    Usable via plain ``with`` (yields the model) or ``model_on_device()``
    (yields ``(state_dict, model)``; the state dict is unused here).
    """

    def __init__(self, model, config=None):
        self._model = model
        self.config = config

    @contextmanager
    def model_on_device(self):
        yield (None, self._model)

    def __enter__(self):
        return self._model

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions.
        return False
def test_sd3_clip_encode_uses_effective_device(monkeypatch):
    """_clip_encode must move token ids to the encoder's *effective* device,
    not the (misleading) device the model reports."""
    module_path = "invokeai.app.invocations.sd3_text_encoder"
    effective_device = torch.device("meta")
    text_encoder = FakeSd3ClipTextEncoder(effective_device)
    tokenizer = FakeClipTokenizer()

    def forward(input_ids: torch.Tensor, output_hidden_states: bool = False):
        assert output_hidden_states
        text_encoder.forward_input_device = input_ids.device
        hidden = input_ids.unsqueeze(-1).float()
        return FakeClipOutput(hidden_states=[hidden, hidden + 1])

    text_encoder.forward = forward  # type: ignore[method-assign]

    mock_context = MagicMock()
    # _clip_encode loads the text encoder first, then the tokenizer.
    mock_context.models.load.side_effect = [
        FakeLoadedModel(text_encoder, config=SimpleNamespace(format=ModelFormat.Diffusers)),
        FakeLoadedModel(tokenizer),
    ]
    mock_context.util.signal_progress = MagicMock()

    monkeypatch.setattr(f"{module_path}.CLIPTextModel", FakeSd3ClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTextModelWithProjection", FakeSd3ClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTokenizer", FakeClipTokenizer)
    monkeypatch.setattr(f"{module_path}.LayerPatcher.apply_smart_model_patches", lambda **kwargs: nullcontext())

    invocation = Sd3TextEncoderInvocation.model_construct(
        clip_l=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
        clip_g=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
        t5_encoder=None,
        prompt="test prompt",
    )

    invocation._clip_encode(
        context=mock_context,
        clip_model=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
    )

    assert text_encoder.forward_input_device == effective_device
def test_sd3_t5_encode_uses_effective_device(monkeypatch):
    """_t5_encode must move token ids to the T5 encoder's *effective* device."""
    module_path = "invokeai.app.invocations.sd3_text_encoder"
    effective_device = torch.device("meta")
    text_encoder = FakeSd3T5Encoder(effective_device)
    tokenizer = FakeT5Tokenizer()

    mock_context = MagicMock()
    # _t5_encode loads the text encoder first, then the tokenizer.
    mock_context.models.load.side_effect = [FakeLoadedModel(text_encoder), FakeLoadedModel(tokenizer)]
    mock_context.util.signal_progress = MagicMock()
    mock_context.logger.warning = MagicMock()

    monkeypatch.setattr(f"{module_path}.T5EncoderModel", FakeSd3T5Encoder)
    monkeypatch.setattr(f"{module_path}.T5Tokenizer", FakeT5Tokenizer)
    monkeypatch.setattr(f"{module_path}.T5TokenizerFast", FakeT5Tokenizer)

    invocation = Sd3TextEncoderInvocation.model_construct(
        clip_l=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
        clip_g=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
        t5_encoder=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace()),
        prompt="test prompt",
    )

    invocation._t5_encode(mock_context, max_seq_len=16)

    assert text_encoder.forward_input_device == effective_device
45
tests/backend/flux/modules/test_conditioner.py
Normal file
45
tests/backend/flux/modules/test_conditioner.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
|
||||
|
||||
class FakeTokenizer:
    """Tokenizer double matching the positional call signature HFEncoder uses."""

    def __call__(
        self,
        text,
        truncation,
        max_length,
        return_length,
        return_overflowing_tokens,
        padding,
        return_tensors,
    ):
        del text, truncation, max_length, return_length, return_overflowing_tokens, padding, return_tensors
        # Fixed 3-token batch regardless of input.
        return {"input_ids": torch.tensor([[1, 2, 3]], dtype=torch.long)}
class FakeEncoderOutput(dict):
    """Dict-based encoder output so ``outputs["pooler_output"]`` works."""


class FakePartiallyLoadedEncoder(torch.nn.Module):
    """Encoder double simulating partial load: a CPU parameter plus a buffer
    on ``effective_device``. ``forward`` records the device of the incoming
    token ids for the test to assert on."""

    def __init__(self, effective_device: torch.device):
        super().__init__()
        self.register_parameter("cpu_param", torch.nn.Parameter(torch.ones(1)))
        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
        self.forward_input_device: torch.device | None = None

    def forward(self, input_ids: torch.Tensor, attention_mask=None, output_hidden_states: bool = False):
        del attention_mask, output_hidden_states
        self.forward_input_device = input_ids.device
        return FakeEncoderOutput(pooler_output=torch.ones((1, 4), dtype=torch.float32))
def test_hf_encoder_uses_effective_device_for_partially_loaded_models():
    """HFEncoder must move token ids to the encoder's *effective* device
    (where its tensors actually live), not a reported/default device."""
    effective_device = torch.device("meta")
    encoder = FakePartiallyLoadedEncoder(effective_device=effective_device)
    hf_encoder = HFEncoder(encoder=encoder, tokenizer=FakeTokenizer(), is_clip=True, max_length=77)

    hf_encoder(["test prompt"])

    assert encoder.forward_input_device == effective_device
82
tests/backend/model_manager/load/test_loaded_model.py
Normal file
82
tests/backend/model_manager/load/test_loaded_model.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
|
||||
apply_custom_layers_to_model,
|
||||
)
|
||||
|
||||
|
||||
class ModelWithRequiredScale(torch.nn.Module):
    """Tiny model with a direct (non-submodule) parameter, ``scale``, of the
    kind a partial loader must keep on the compute device."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.scale = torch.nn.Parameter(torch.ones(4))
class FakeCache:
    """Model-cache double that only counts lock/unlock calls."""

    def __init__(self):
        self.lock_calls = 0
        self.unlock_calls = 0

    def lock(self, cache_record: CacheRecord, working_mem_bytes: int | None) -> None:
        del cache_record, working_mem_bytes
        self.lock_calls += 1

    def unlock(self, cache_record: CacheRecord) -> None:
        del cache_record
        self.unlock_calls += 1
def test_model_on_device_repairs_required_tensors_for_partial_models():
    """For a partially loaded model, model_on_device() should move the
    direct "required" parameter (``scale``) to the compute device while the
    submodule weights may remain on CPU."""
    model = ModelWithRequiredScale()
    apply_custom_layers_to_model(model, device_autocasting_enabled=True)
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("meta"), keep_ram_copy=False)
    loaded_model = LoadedModelWithoutConfig(
        cache_record=CacheRecord(key="test", cached_model=cached_model), cache=FakeCache()
    )

    with loaded_model.model_on_device():
        assert model.scale.device.type == "meta"
        assert all(param.device.type == "cpu" for param in model.linear.parameters())
def test_model_on_device_leaves_full_load_models_unchanged():
    """A fully-loaded (non-partial) cached model needs no tensor repair:
    model_on_device() yields the same model object with its tensors untouched."""
    model = torch.nn.Linear(4, 4)
    cached_model = CachedModelOnlyFullLoad(
        model=model, compute_device=torch.device("meta"), total_bytes=1, keep_ram_copy=False
    )
    loaded_model = LoadedModelWithoutConfig(
        cache_record=CacheRecord(key="test", cached_model=cached_model), cache=FakeCache()
    )

    with loaded_model.model_on_device() as (_, returned_model):
        assert returned_model is model
        assert all(param.device.type == "cpu" for param in model.parameters())
def test_enter_unlocks_if_repair_raises():
    """If tensor repair fails after the cache lock is taken, __enter__ must
    release the lock before propagating the error (no leaked locks)."""

    class BrokenCachedModel(CachedModelWithPartialLoad):
        def repair_required_tensors_on_compute_device(self) -> int:
            raise RuntimeError("repair failed")

    model = ModelWithRequiredScale()
    apply_custom_layers_to_model(model, device_autocasting_enabled=True)
    cached_model = BrokenCachedModel(model=model, compute_device=torch.device("meta"), keep_ram_copy=False)
    fake_cache = FakeCache()
    loaded_model = LoadedModelWithoutConfig(
        cache_record=CacheRecord(key="test", cached_model=cached_model), cache=fake_cache
    )

    with pytest.raises(RuntimeError, match="repair failed"):
        loaded_model.__enter__()

    # The lock that __enter__ took must have been released on failure.
    assert fake_cache.lock_calls == 1
    assert fake_cache.unlock_calls == 1
Reference in New Issue
Block a user