From ee600973ed42a5aa960cd202caa867edfd3eadb7 Mon Sep 17 00:00:00 2001
From: Jonathan <34005131+JPPhoto@users.noreply.github.com>
Date: Thu, 9 Apr 2026 20:09:40 -0400
Subject: [PATCH] Broaden text encoder partial-load recovery (#9034)

---
 invokeai/app/invocations/compel.py            |   5 +-
 invokeai/app/invocations/sd3_text_encoder.py  |   9 +-
 invokeai/backend/flux/modules/conditioner.py  |   4 +-
 .../backend/model_manager/load/load_base.py   |   8 +-
 tests/app/invocations/test_compel.py          | 139 ++++++++++++++++
 .../app/invocations/test_sd3_text_encoder.py  | 154 ++++++++++++++++++
 .../backend/flux/modules/test_conditioner.py  |  45 +++++
 .../model_manager/load/test_loaded_model.py   |  82 ++++++++++
 8 files changed, 438 insertions(+), 8 deletions(-)
 create mode 100644 tests/app/invocations/test_compel.py
 create mode 100644 tests/app/invocations/test_sd3_text_encoder.py
 create mode 100644 tests/backend/flux/modules/test_conditioner.py
 create mode 100644 tests/backend/model_manager/load/test_loaded_model.py

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 5ce88145ff..0ff6be969f 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -19,6 +19,7 @@ from invokeai.app.invocations.model import CLIPField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -103,7 +104,7 @@ class CompelInvocation(BaseInvocation):
             textual_inversion_manager=ti_manager,
             dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,
-            device=text_encoder.device,  # Use the device the model is actually on
+            device=get_effective_device(text_encoder),
             split_long_text_mode=SplitLongTextMode.SENTENCES,
         )
 
@@ -212,7 +213,7 @@ class SDXLPromptInvocationBase:
             truncate_long_prompts=False,  # TODO:
             returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
             requires_pooled=get_pooled,
-            device=text_encoder.device,  # Use the device the model is actually on
+            device=get_effective_device(text_encoder),
             split_long_text_mode=SplitLongTextMode.SENTENCES,
         )
 
diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py
index 24647c9cfc..58880f9a28 100644
--- a/invokeai/app/invocations/sd3_text_encoder.py
+++ b/invokeai/app/invocations/sd3_text_encoder.py
@@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import SD3ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
 from invokeai.backend.model_manager.taxonomy import ModelFormat
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
@@ -103,6 +104,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
             context.util.signal_progress("Running T5 encoder")
             assert isinstance(t5_text_encoder, T5EncoderModel)
             assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
+    tokenizer_info = FakeLoadedModel(tokenizer)
+
+    mock_context = MagicMock()
+    mock_context.models.load.side_effect = [text_encoder_info, tokenizer_info]
+    mock_context.config.get.return_value.log_tokenization = False
+    mock_context.util.signal_progress = MagicMock()
+
+    monkeypatch.setattr(f"{module_path}.CLIPTextModel", FakeClipTextEncoder)
+    monkeypatch.setattr(f"{module_path}.CLIPTextModelWithProjection", FakeClipTextEncoder)
+    monkeypatch.setattr(f"{module_path}.CLIPTokenizer", FakeTokenizer)
+    monkeypatch.setattr(f"{module_path}.Compel", FakeCompel)
+    monkeypatch.setattr(f"{module_path}.generate_ti_list", lambda prompt, base, context: [])
+    monkeypatch.setattr(f"{module_path}.LayerPatcher.apply_smart_model_patches", lambda **kwargs: nullcontext())
+    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_clip_skip", lambda *args, **kwargs: nullcontext())
+    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_ti", fake_apply_ti)
+
+    base = SDXLPromptInvocationBase()
+    base.run_clip_compel(
+        context=mock_context,
+        clip_field=SimpleNamespace(
+            text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[], skipped_layers=0
+        ),
+        prompt="test prompt",
+        get_pooled=False,
+        lora_prefix="lora_te1_",
+        zero_on_empty=False,
+    )
+
+    assert FakeCompel.last_init_device == torch.device("cpu")
diff --git a/tests/app/invocations/test_sd3_text_encoder.py b/tests/app/invocations/test_sd3_text_encoder.py
new file mode 100644
index 0000000000..560dcba43b
--- /dev/null
+++ b/tests/app/invocations/test_sd3_text_encoder.py
@@ -0,0 +1,154 @@
+from contextlib import contextmanager, nullcontext
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import torch
+
+from invokeai.app.invocations.sd3_text_encoder import Sd3TextEncoderInvocation
+from invokeai.backend.model_manager.taxonomy import ModelFormat
+
+
+class FakeSd3ClipTextEncoder(torch.nn.Module):
+    def __init__(self, effective_device: torch.device):
+        super().__init__()
+        self.register_parameter("cpu_param", torch.nn.Parameter(torch.ones(1)))
+        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
+        self.dtype = torch.float32
+        self.forward_input_device: torch.device | None = None
+
+    @property
+    def device(self) -> torch.device:
+        return torch.device("cpu")
+
+    def forward(self, input_ids: torch.Tensor, output_hidden_states: bool = False):
+        assert output_hidden_states
+        self.forward_input_device = input_ids.device
+        hidden = input_ids.unsqueeze(-1).float()
+        return FakeClipOutput(hidden_states=[hidden, hidden + 1])
+
+
+class FakeClipOutput(SimpleNamespace):
+    def __getitem__(self, idx):
+        del idx
+        return self.hidden_states[-1]
+
+
+class FakeClipTokenizer:
+    def __call__(self, prompt, padding, max_length=None, truncation=None, return_tensors=None):
+        del prompt, padding, max_length, truncation, return_tensors
+        return SimpleNamespace(input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long))
+
+    def batch_decode(self, input_ids):
+        del input_ids
+        return ["decoded"]
+
+
+class FakeSd3T5Encoder(torch.nn.Module):
+    def __init__(self, effective_device: torch.device):
+        super().__init__()
+        self.register_parameter("cpu_param", torch.nn.Parameter(torch.ones(1)))
+        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
+        self.forward_input_device: torch.device | None = None
+
+    @property
+    def device(self) -> torch.device:
+        return torch.device("cpu")
+
+    def forward(self, input_ids: torch.Tensor):
+        self.forward_input_device = input_ids.device
+        hidden = input_ids.unsqueeze(-1).float()
+        return (hidden,)
+
+
+class FakeT5Tokenizer:
+    def __call__(self, prompt, padding, max_length=None, truncation=None, add_special_tokens=None, return_tensors=None):
+        del prompt, padding, max_length, truncation, add_special_tokens, return_tensors
+        return SimpleNamespace(input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long))
+
+    def batch_decode(self, input_ids):
+        del input_ids
+        return ["decoded"]
+
+
+class FakeLoadedModel:
+    def __init__(self, model, config=None):
+        self._model = model
+        self.config = config
+
+    @contextmanager
+    def model_on_device(self):
+        yield (None, self._model)
+
+    def __enter__(self):
+        return self._model
+
+    def __exit__(self, exc_type, exc, tb):
+        return False
+
+
+def test_sd3_clip_encode_uses_effective_device(monkeypatch):
+    module_path = "invokeai.app.invocations.sd3_text_encoder"
+    effective_device = torch.device("meta")
+    text_encoder = FakeSd3ClipTextEncoder(effective_device)
+    tokenizer = FakeClipTokenizer()
+
+    def forward(input_ids: torch.Tensor, output_hidden_states: bool = False):
+        assert output_hidden_states
+        text_encoder.forward_input_device = input_ids.device
+        hidden = input_ids.unsqueeze(-1).float()
+        return FakeClipOutput(hidden_states=[hidden, hidden + 1])
+
+    text_encoder.forward = forward  # type: ignore[method-assign]
+
+    mock_context = MagicMock()
+    mock_context.models.load.side_effect = [
+        FakeLoadedModel(text_encoder, config=SimpleNamespace(format=ModelFormat.Diffusers)),
+        FakeLoadedModel(tokenizer),
+    ]
+    mock_context.util.signal_progress = MagicMock()
+
+    monkeypatch.setattr(f"{module_path}.CLIPTextModel", FakeSd3ClipTextEncoder)
+    monkeypatch.setattr(f"{module_path}.CLIPTextModelWithProjection", FakeSd3ClipTextEncoder)
+    monkeypatch.setattr(f"{module_path}.CLIPTokenizer", FakeClipTokenizer)
+    monkeypatch.setattr(f"{module_path}.LayerPatcher.apply_smart_model_patches", lambda **kwargs: nullcontext())
+
+    invocation = Sd3TextEncoderInvocation.model_construct(
+        clip_l=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
+        clip_g=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
+        t5_encoder=None,
+        prompt="test prompt",
+    )
+
+    invocation._clip_encode(
+        context=mock_context,
+        clip_model=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
+    )
+
+    assert text_encoder.forward_input_device == effective_device
+
+
+def test_sd3_t5_encode_uses_effective_device(monkeypatch):
+    module_path = "invokeai.app.invocations.sd3_text_encoder"
+    effective_device = torch.device("meta")
+    text_encoder = FakeSd3T5Encoder(effective_device)
+    tokenizer = FakeT5Tokenizer()
+
+    mock_context = MagicMock()
+    mock_context.models.load.side_effect = [FakeLoadedModel(text_encoder), FakeLoadedModel(tokenizer)]
+    mock_context.util.signal_progress = MagicMock()
+    mock_context.logger.warning = MagicMock()
+
+    monkeypatch.setattr(f"{module_path}.T5EncoderModel", FakeSd3T5Encoder)
+    monkeypatch.setattr(f"{module_path}.T5Tokenizer", FakeT5Tokenizer)
+    monkeypatch.setattr(f"{module_path}.T5TokenizerFast", FakeT5Tokenizer)
+
+    invocation = Sd3TextEncoderInvocation.model_construct(
+        clip_l=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
+        clip_g=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[]),
+        t5_encoder=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace()),
+        prompt="test prompt",
prompt", + ) + + invocation._t5_encode(mock_context, max_seq_len=16) + + assert text_encoder.forward_input_device == effective_device diff --git a/tests/backend/flux/modules/test_conditioner.py b/tests/backend/flux/modules/test_conditioner.py new file mode 100644 index 0000000000..9ccc366c8d --- /dev/null +++ b/tests/backend/flux/modules/test_conditioner.py @@ -0,0 +1,45 @@ +import torch + +from invokeai.backend.flux.modules.conditioner import HFEncoder + + +class FakeTokenizer: + def __call__( + self, + text, + truncation, + max_length, + return_length, + return_overflowing_tokens, + padding, + return_tensors, + ): + del text, truncation, max_length, return_length, return_overflowing_tokens, padding, return_tensors + return {"input_ids": torch.tensor([[1, 2, 3]], dtype=torch.long)} + + +class FakeEncoderOutput(dict): + pass + + +class FakePartiallyLoadedEncoder(torch.nn.Module): + def __init__(self, effective_device: torch.device): + super().__init__() + self.register_parameter("cpu_param", torch.nn.Parameter(torch.ones(1))) + self.register_buffer("active_buffer", torch.ones(1, device=effective_device)) + self.forward_input_device: torch.device | None = None + + def forward(self, input_ids: torch.Tensor, attention_mask=None, output_hidden_states: bool = False): + del attention_mask, output_hidden_states + self.forward_input_device = input_ids.device + return FakeEncoderOutput(pooler_output=torch.ones((1, 4), dtype=torch.float32)) + + +def test_hf_encoder_uses_effective_device_for_partially_loaded_models(): + effective_device = torch.device("meta") + encoder = FakePartiallyLoadedEncoder(effective_device=effective_device) + hf_encoder = HFEncoder(encoder=encoder, tokenizer=FakeTokenizer(), is_clip=True, max_length=77) + + hf_encoder(["test prompt"]) + + assert encoder.forward_input_device == effective_device diff --git a/tests/backend/model_manager/load/test_loaded_model.py b/tests/backend/model_manager/load/test_loaded_model.py new file mode 100644 index 0000000000..b792475b84 --- /dev/null +++ b/tests/backend/model_manager/load/test_loaded_model.py @@ -0,0 +1,82 @@ +import pytest +import torch + +from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig +from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import ( + CachedModelOnlyFullLoad, +) +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( + CachedModelWithPartialLoad, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + apply_custom_layers_to_model, +) + + +class ModelWithRequiredScale(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + self.scale = torch.nn.Parameter(torch.ones(4)) + + +class FakeCache: + def __init__(self): + self.lock_calls = 0 + self.unlock_calls = 0 + + def lock(self, cache_record: CacheRecord, working_mem_bytes: int | None) -> None: + del cache_record, working_mem_bytes + self.lock_calls += 1 + + def unlock(self, cache_record: CacheRecord) -> None: + del cache_record + self.unlock_calls += 1 + + +def test_model_on_device_repairs_required_tensors_for_partial_models(): + model = ModelWithRequiredScale() + apply_custom_layers_to_model(model, device_autocasting_enabled=True) + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("meta"), keep_ram_copy=False) + loaded_model = 
+        cache_record=CacheRecord(key="test", cached_model=cached_model), cache=FakeCache()
+    )
+
+    with loaded_model.model_on_device():
+        assert model.scale.device.type == "meta"
+        assert all(param.device.type == "cpu" for param in model.linear.parameters())
+
+
+def test_model_on_device_leaves_full_load_models_unchanged():
+    model = torch.nn.Linear(4, 4)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device("meta"), total_bytes=1, keep_ram_copy=False
+    )
+    loaded_model = LoadedModelWithoutConfig(
+        cache_record=CacheRecord(key="test", cached_model=cached_model), cache=FakeCache()
+    )
+
+    with loaded_model.model_on_device() as (_, returned_model):
+        assert returned_model is model
+        assert all(param.device.type == "cpu" for param in model.parameters())
+
+
+def test_enter_unlocks_if_repair_raises():
+    class BrokenCachedModel(CachedModelWithPartialLoad):
+        def repair_required_tensors_on_compute_device(self) -> int:
+            raise RuntimeError("repair failed")
+
+    model = ModelWithRequiredScale()
+    apply_custom_layers_to_model(model, device_autocasting_enabled=True)
+    cached_model = BrokenCachedModel(model=model, compute_device=torch.device("meta"), keep_ram_copy=False)
+    fake_cache = FakeCache()
+    loaded_model = LoadedModelWithoutConfig(
+        cache_record=CacheRecord(key="test", cached_model=cached_model), cache=fake_cache
+    )
+
+    with pytest.raises(RuntimeError, match="repair failed"):
+        loaded_model.__enter__()
+
+    assert fake_cache.lock_calls == 1
+    assert fake_cache.unlock_calls == 1