Broaden text encoder partial-load recovery (#9034)

Jonathan authored 2026-04-09 20:09:40 -04:00, committed by GitHub
parent d4c0e631e2
commit ee600973ed
8 changed files with 438 additions and 8 deletions

View File

@@ -0,0 +1,139 @@
from contextlib import contextmanager, nullcontext
from types import SimpleNamespace
from unittest.mock import MagicMock

import torch

from invokeai.app.invocations.compel import SDXLPromptInvocationBase


class FakeClipTextEncoder(torch.nn.Module):
    def __init__(self, effective_device: torch.device):
        super().__init__()
        self.register_parameter("cpu_param", torch.nn.Parameter(torch.ones(1)))
        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
        self.dtype = torch.float32

    @property
    def device(self) -> torch.device:
        return torch.device("cpu")


class FakeTokenizer:
    pass


class FakeLoadedModel:
    def __init__(self, model, config=None):
        self._model = model
        self.config = config

    @contextmanager
    def model_on_device(self):
        yield (None, self._model)

    def __enter__(self):
        return self._model

    def __exit__(self, exc_type, exc, tb):
        return False


class FakeCompel:
    last_init_device: torch.device | None = None

    def __init__(self, *args, device: torch.device, **kwargs):
        del args, kwargs
        FakeCompel.last_init_device = device
        self.conditioning_provider = SimpleNamespace(
            get_pooled_embeddings=lambda prompts: torch.ones((len(prompts), 4), dtype=torch.float32)
        )

    @staticmethod
    def parse_prompt_string(prompt: str) -> str:
        return prompt

    def build_conditioning_tensor_for_conjunction(self, conjunction: str):
        del conjunction
        return torch.ones((1, 4, 4), dtype=torch.float32), {}


@contextmanager
def fake_apply_ti(tokenizer, text_encoder, ti_list):
    del text_encoder, ti_list
    yield tokenizer, object()


def test_sdxl_run_clip_compel_uses_effective_device_for_partially_loaded_model(monkeypatch):
    module_path = "invokeai.app.invocations.compel"
    effective_device = torch.device("meta")
    text_encoder = FakeClipTextEncoder(effective_device=effective_device)
    tokenizer = FakeTokenizer()
    text_encoder_info = FakeLoadedModel(text_encoder, config=SimpleNamespace(base="sdxl"))
    tokenizer_info = FakeLoadedModel(tokenizer)

    mock_context = MagicMock()
    mock_context.models.load.side_effect = [text_encoder_info, tokenizer_info]
    mock_context.config.get.return_value.log_tokenization = False
    mock_context.util.signal_progress = MagicMock()

    monkeypatch.setattr(f"{module_path}.CLIPTextModel", FakeClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTextModelWithProjection", FakeClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTokenizer", FakeTokenizer)
    monkeypatch.setattr(f"{module_path}.Compel", FakeCompel)
    monkeypatch.setattr(f"{module_path}.generate_ti_list", lambda prompt, base, context: [])
    monkeypatch.setattr(f"{module_path}.LayerPatcher.apply_smart_model_patches", lambda **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_clip_skip", lambda *args, **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_ti", fake_apply_ti)

    base = SDXLPromptInvocationBase()
    cond, pooled = base.run_clip_compel(
        context=mock_context,
        clip_field=SimpleNamespace(
            text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[], skipped_layers=0
        ),
        prompt="test prompt",
        get_pooled=False,
        lora_prefix="lora_te1_",
        zero_on_empty=False,
    )

    assert FakeCompel.last_init_device == effective_device
    assert cond.shape == (1, 4, 4)
    assert pooled is None


def test_sdxl_run_clip_compel_uses_cpu_for_fully_cpu_model(monkeypatch):
    module_path = "invokeai.app.invocations.compel"
    text_encoder = FakeClipTextEncoder(effective_device=torch.device("cpu"))
    tokenizer = FakeTokenizer()
    text_encoder_info = FakeLoadedModel(text_encoder, config=SimpleNamespace(base="sdxl"))
    tokenizer_info = FakeLoadedModel(tokenizer)

    mock_context = MagicMock()
    mock_context.models.load.side_effect = [text_encoder_info, tokenizer_info]
    mock_context.config.get.return_value.log_tokenization = False
    mock_context.util.signal_progress = MagicMock()

    monkeypatch.setattr(f"{module_path}.CLIPTextModel", FakeClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTextModelWithProjection", FakeClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTokenizer", FakeTokenizer)
    monkeypatch.setattr(f"{module_path}.Compel", FakeCompel)
    monkeypatch.setattr(f"{module_path}.generate_ti_list", lambda prompt, base, context: [])
    monkeypatch.setattr(f"{module_path}.LayerPatcher.apply_smart_model_patches", lambda **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_clip_skip", lambda *args, **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_ti", fake_apply_ti)

    base = SDXLPromptInvocationBase()
    base.run_clip_compel(
        context=mock_context,
        clip_field=SimpleNamespace(
            text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[], skipped_layers=0
        ),
        prompt="test prompt",
        get_pooled=False,
        lora_prefix="lora_te1_",
        zero_on_empty=False,
    )

    assert FakeCompel.last_init_device == torch.device("cpu")
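
A note on the fixture design: FakeClipTextEncoder reports device == cpu while holding a buffer on the target device, which is exactly how a partially loaded model looks to callers. The tests assert that run_clip_compel derives Compel's device from the tensors themselves rather than from the device property. A minimal sketch of such an "effective device" lookup, with an illustrative helper name rather than the shipped InvokeAI API:

import torch

def effective_device(module: torch.nn.Module) -> torch.device:
    # A partially loaded module reports device == cpu even though some of its
    # tensors already live on the compute device. Scan parameters and buffers
    # and return the first non-CPU device found.
    for tensor in list(module.parameters()) + list(module.buffers()):
        if tensor.device.type != "cpu":
            return tensor.device
    return torch.device("cpu")

Applied to the fakes above, this returns meta for the partially loaded encoder and cpu for the fully-CPU one, matching the two assertions.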

View File

@@ -0,0 +1,154 @@
from contextlib import contextmanager, nullcontext
from types import SimpleNamespace
from unittest.mock import MagicMock

import torch

from invokeai.app.invocations.sd3_text_encoder import Sd3TextEncoderInvocation
from invokeai.backend.model_manager.taxonomy import ModelFormat


class FakeSd3ClipTextEncoder(torch.nn.Module):
    def __init__(self, effective_device: torch.device):
        super().__init__()
        self.register_parameter("cpu_param", torch.nn.Parameter(torch.ones(1)))
        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
        self.dtype = torch.float32
        self.forward_input_device: torch.device | None = None

    @property
    def device(self) -> torch.device:
        return torch.device("cpu")

    def forward(self, input_ids: torch.Tensor, output_hidden_states: bool = False):
        assert output_hidden_states
        self.forward_input_device = input_ids.device
        hidden = input_ids.unsqueeze(-1).float()
        # A dunder assigned on a SimpleNamespace instance is never used for
        # subscripting, so return a FakeClipOutput (defined below), which is
        # actually indexable.
        return FakeClipOutput(hidden_states=[hidden, hidden + 1])


class FakeClipOutput(SimpleNamespace):
    def __getitem__(self, idx):
        del idx
        return self.hidden_states[-1]


class FakeClipTokenizer:
    def __call__(self, prompt, padding, max_length=None, truncation=None, return_tensors=None):
        del prompt, padding, max_length, truncation, return_tensors
        return SimpleNamespace(input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long))

    def batch_decode(self, input_ids):
        del input_ids
        return ["decoded"]


class FakeSd3T5Encoder(torch.nn.Module):
    def __init__(self, effective_device: torch.device):
        super().__init__()
        self.register_parameter("cpu_param", torch.nn.Parameter(torch.ones(1)))
        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
        self.forward_input_device: torch.device | None = None

    @property
    def device(self) -> torch.device:
        return torch.device("cpu")

    def forward(self, input_ids: torch.Tensor):
        self.forward_input_device = input_ids.device
        hidden = input_ids.unsqueeze(-1).float()
        return (hidden,)


class FakeT5Tokenizer:
    def __call__(
        self, prompt, padding, max_length=None, truncation=None, add_special_tokens=None, return_tensors=None
    ):
        del prompt, padding, max_length, truncation, add_special_tokens, return_tensors
        return SimpleNamespace(input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long))

    def batch_decode(self, input_ids):
        del input_ids
        return ["decoded"]


class FakeLoadedModel:
    def __init__(self, model, config=None):
        self._model = model
        self.config = config

    @contextmanager
    def model_on_device(self):
        yield (None, self._model)

    def __enter__(self):
        return self._model

    def __exit__(self, exc_type, exc, tb):
        return False


def test_sd3_clip_encode_uses_effective_device(monkeypatch):
    module_path = "invokeai.app.invocations.sd3_text_encoder"
    effective_device = torch.device("meta")
    text_encoder = FakeSd3ClipTextEncoder(effective_device)
    tokenizer = FakeClipTokenizer()

    def forward(input_ids: torch.Tensor, output_hidden_states: bool = False):
        assert output_hidden_states
        text_encoder.forward_input_device = input_ids.device
        hidden = input_ids.unsqueeze(-1).float()
        return FakeClipOutput(hidden_states=[hidden, hidden + 1])

    text_encoder.forward = forward  # type: ignore[method-assign]

    mock_context = MagicMock()
    mock_context.models.load.side_effect = [
        FakeLoadedModel(text_encoder, config=SimpleNamespace(format=ModelFormat.Diffusers)),
        FakeLoadedModel(tokenizer),
    ]
    mock_context.util.signal_progress = MagicMock()

    monkeypatch.setattr(f"{module_path}.CLIPTextModel", FakeSd3ClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTextModelWithProjection", FakeSd3ClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTokenizer", FakeClipTokenizer)
    monkeypatch.setattr(f"{module_path}.LayerPatcher.apply_smart_model_patches", lambda **kwargs: nullcontext())

    invocation = Sd3TextEncoderInvocation.model_construct(
        clip_l=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
        clip_g=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
        t5_encoder=None,
        prompt="test prompt",
    )
    invocation._clip_encode(
        context=mock_context,
        clip_model=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
    )

    assert text_encoder.forward_input_device == effective_device


def test_sd3_t5_encode_uses_effective_device(monkeypatch):
    module_path = "invokeai.app.invocations.sd3_text_encoder"
    effective_device = torch.device("meta")
    text_encoder = FakeSd3T5Encoder(effective_device)
    tokenizer = FakeT5Tokenizer()

    mock_context = MagicMock()
    mock_context.models.load.side_effect = [FakeLoadedModel(text_encoder), FakeLoadedModel(tokenizer)]
    mock_context.util.signal_progress = MagicMock()
    mock_context.logger.warning = MagicMock()

    monkeypatch.setattr(f"{module_path}.T5EncoderModel", FakeSd3T5Encoder)
    monkeypatch.setattr(f"{module_path}.T5Tokenizer", FakeT5Tokenizer)
    monkeypatch.setattr(f"{module_path}.T5TokenizerFast", FakeT5Tokenizer)

    invocation = Sd3TextEncoderInvocation.model_construct(
        clip_l=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
        clip_g=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
        t5_encoder=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace()),
        prompt="test prompt",
    )
    invocation._t5_encode(mock_context, max_seq_len=16)

    assert text_encoder.forward_input_device == effective_device
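
Both SD3 tests pin down the same caller-side contract: _clip_encode and _t5_encode must move the tokenized input_ids to wherever the encoder's weights actually live before calling forward, instead of trusting the module's device property. A minimal sketch of that pattern, reusing the illustrative effective_device helper from the note above (encode_for_sd3 is a hypothetical name, not the invocation code):

import torch

def encode_for_sd3(text_encoder: torch.nn.Module, input_ids: torch.Tensor):
    # For a fully loaded CPU model this move is a no-op; for a partially
    # loaded model it targets the compute device, which is what the
    # forward_input_device assertions above verify.
    return text_encoder(input_ids.to(effective_device(text_encoder)), output_hidden_states=True)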

View File

@@ -0,0 +1,45 @@
import torch

from invokeai.backend.flux.modules.conditioner import HFEncoder


class FakeTokenizer:
    def __call__(
        self,
        text,
        truncation,
        max_length,
        return_length,
        return_overflowing_tokens,
        padding,
        return_tensors,
    ):
        del text, truncation, max_length, return_length, return_overflowing_tokens, padding, return_tensors
        return {"input_ids": torch.tensor([[1, 2, 3]], dtype=torch.long)}


class FakeEncoderOutput(dict):
    pass


class FakePartiallyLoadedEncoder(torch.nn.Module):
    def __init__(self, effective_device: torch.device):
        super().__init__()
        self.register_parameter("cpu_param", torch.nn.Parameter(torch.ones(1)))
        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
        self.forward_input_device: torch.device | None = None

    def forward(self, input_ids: torch.Tensor, attention_mask=None, output_hidden_states: bool = False):
        del attention_mask, output_hidden_states
        self.forward_input_device = input_ids.device
        return FakeEncoderOutput(pooler_output=torch.ones((1, 4), dtype=torch.float32))


def test_hf_encoder_uses_effective_device_for_partially_loaded_models():
    effective_device = torch.device("meta")
    encoder = FakePartiallyLoadedEncoder(effective_device=effective_device)
    hf_encoder = HFEncoder(encoder=encoder, tokenizer=FakeTokenizer(), is_clip=True, max_length=77)

    hf_encoder(["test prompt"])

    assert encoder.forward_input_device == effective_device
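
The FLUX test exercises the same behavior one layer lower, inside HFEncoder itself, and FakeEncoderOutput subclasses dict only so the CLIP branch can read the pooled embedding by key. A rough sketch of the call shape being faked, with the internals of conditioner.py assumed rather than quoted:

import torch

def fake_clip_call(encoder: torch.nn.Module, tokenizer, prompts: list[str]) -> torch.Tensor:
    tokens = tokenizer(
        prompts,
        truncation=True,
        max_length=77,
        return_length=False,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    # The move under test: ids go to the effective device (helper sketched
    # earlier), not to encoder.device, before the forward call.
    input_ids = tokens["input_ids"].to(effective_device(encoder))
    return encoder(input_ids=input_ids, output_hidden_states=False)["pooler_output"]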

View File

@@ -0,0 +1,82 @@
import pytest
import torch
from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
apply_custom_layers_to_model,
)
class ModelWithRequiredScale(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(4, 4)
self.scale = torch.nn.Parameter(torch.ones(4))
class FakeCache:
def __init__(self):
self.lock_calls = 0
self.unlock_calls = 0
def lock(self, cache_record: CacheRecord, working_mem_bytes: int | None) -> None:
del cache_record, working_mem_bytes
self.lock_calls += 1
def unlock(self, cache_record: CacheRecord) -> None:
del cache_record
self.unlock_calls += 1
def test_model_on_device_repairs_required_tensors_for_partial_models():
model = ModelWithRequiredScale()
apply_custom_layers_to_model(model, device_autocasting_enabled=True)
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("meta"), keep_ram_copy=False)
loaded_model = LoadedModelWithoutConfig(
cache_record=CacheRecord(key="test", cached_model=cached_model), cache=FakeCache()
)
with loaded_model.model_on_device():
assert model.scale.device.type == "meta"
assert all(param.device.type == "cpu" for param in model.linear.parameters())
def test_model_on_device_leaves_full_load_models_unchanged():
model = torch.nn.Linear(4, 4)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device("meta"), total_bytes=1, keep_ram_copy=False
)
loaded_model = LoadedModelWithoutConfig(
cache_record=CacheRecord(key="test", cached_model=cached_model), cache=FakeCache()
)
with loaded_model.model_on_device() as (_, returned_model):
assert returned_model is model
assert all(param.device.type == "cpu" for param in model.parameters())
def test_enter_unlocks_if_repair_raises():
class BrokenCachedModel(CachedModelWithPartialLoad):
def repair_required_tensors_on_compute_device(self) -> int:
raise RuntimeError("repair failed")
model = ModelWithRequiredScale()
apply_custom_layers_to_model(model, device_autocasting_enabled=True)
cached_model = BrokenCachedModel(model=model, compute_device=torch.device("meta"), keep_ram_copy=False)
fake_cache = FakeCache()
loaded_model = LoadedModelWithoutConfig(
cache_record=CacheRecord(key="test", cached_model=cached_model), cache=fake_cache
)
with pytest.raises(RuntimeError, match="repair failed"):
loaded_model.__enter__()
assert fake_cache.lock_calls == 1
assert fake_cache.unlock_calls == 1
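
The final test encodes the cleanup contract for the new repair step. A minimal sketch of the __enter__ ordering it pins down, assuming the real load_base.py differs in detail:

def enter(cache: FakeCache, cache_record: CacheRecord):
    # Lock first, then attempt the partial-load repair; if the repair raises,
    # release the lock before propagating so the cache is never left locked.
    cache.lock(cache_record, working_mem_bytes=None)
    try:
        cached = cache_record.cached_model
        if isinstance(cached, CachedModelWithPartialLoad):
            cached.repair_required_tensors_on_compute_device()
    except Exception:
        cache.unlock(cache_record)
        raise
    return cached.model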