Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-04-23 03:00:31 -04:00)

Commit: Merge branch 'main' into external-models
@@ -6,6 +6,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.model import GlmEncoderField
 from invokeai.app.invocations.primitives import CogView4ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     CogView4ConditioningInfo,
     ConditioningFieldData,
@@ -46,10 +47,18 @@ class CogView4TextEncoderInvocation(BaseInvocation):
         prompt = [self.prompt]

         # TODO(ryand): Add model inputs to the invocation rather than hard-coding.
+        glm_text_encoder_info = context.models.load(self.glm_encoder.text_encoder)
         with (
-            context.models.load(self.glm_encoder.text_encoder).model_on_device() as (_, glm_text_encoder),
+            glm_text_encoder_info.model_on_device() as (_, glm_text_encoder),
             context.models.load(self.glm_encoder.tokenizer).model_on_device() as (_, glm_tokenizer),
         ):
+            repaired_tensors = glm_text_encoder_info.repair_required_tensors_on_device()
+            device = get_effective_device(glm_text_encoder)
+            if repaired_tensors > 0:
+                context.logger.warning(
+                    f"Recovered {repaired_tensors} required GLM tensor(s) onto {device} after a partial device mismatch."
+                )
+
             context.util.signal_progress("Running GLM text encoder")
             assert isinstance(glm_text_encoder, GlmModel)
             assert isinstance(glm_tokenizer, PreTrainedTokenizerFast)
@@ -85,9 +94,7 @@ class CogView4TextEncoderInvocation(BaseInvocation):
                 device=text_input_ids.device,
             )
             text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
-        prompt_embeds = glm_text_encoder(
-            text_input_ids.to(glm_text_encoder.device), output_hidden_states=True
-        ).hidden_states[-2]
+        prompt_embeds = glm_text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]

         assert isinstance(prompt_embeds, torch.Tensor)
         return prompt_embeds
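The three hunks above share one idea: once a model can be partially loaded, `glm_text_encoder.device` is no longer a reliable execution device, because a partially loaded module reports the device of whichever parameter happens to be inspected. The commit therefore keeps the loaded-model handle (`glm_text_encoder_info`), repairs any required tensors that a previous interrupted run left on the CPU, and derives the device via `get_effective_device`. A minimal sketch of what such a helper has to do (an illustration only, not InvokeAI's actual implementation, which lives in invokeai.backend.model_manager.load.model_cache.utils):

    import torch

    def effective_device_sketch(model: torch.nn.Module) -> torch.device:
        # With partial loads, parameters are split across devices. Asking the
        # module for "its" device can answer "cpu" even though the hot layers
        # execute on the GPU, so prefer any non-CPU parameter when one exists.
        for param in model.parameters():
            if param.device.type != "cpu":
                return param.device
        return torch.device("cpu")

With that device in hand, the input ids are moved with `text_input_ids.to(device)` instead of `.to(glm_text_encoder.device)`, as the third hunk shows.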
@@ -25,6 +25,7 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.model import Qwen3EncoderField
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_T5_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -100,7 +101,12 @@ class Flux2KleinTextEncoderInvocation(BaseInvocation):
             tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
             (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())

-            device = text_encoder.device
+            repaired_tensors = text_encoder_info.repair_required_tensors_on_device()
+            device = get_effective_device(text_encoder)
+            if repaired_tensors > 0:
+                context.logger.warning(
+                    f"Recovered {repaired_tensors} required Qwen3 tensor(s) onto {device} after a partial device mismatch."
+                )

             # Apply LoRA models
             lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
@@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.model import Qwen3EncoderField
 from invokeai.app.invocations.primitives import ZImageConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_QWEN3_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -76,11 +77,17 @@ class ZImageTextEncoderInvocation(BaseInvocation):
         tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)

         with ExitStack() as exit_stack:
-            (_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
+            (cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
             (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())

-            # Use the device that the text_encoder is actually on
-            device = text_encoder.device
+            # Use the device that the text encoder is effectively executing on, and repair any required tensors left on
+            # the CPU by a previous interrupted run.
+            repaired_tensors = text_encoder_info.repair_required_tensors_on_device()
+            device = get_effective_device(text_encoder)
+            if repaired_tensors > 0:
+                context.logger.warning(
+                    f"Recovered {repaired_tensors} required Qwen3 tensor(s) onto {device} after a partial device mismatch."
+                )

             # Apply LoRA models to the text encoder
             lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
@@ -90,6 +97,7 @@ class ZImageTextEncoderInvocation(BaseInvocation):
                     patches=self._lora_iterator(context),
                     prefix=Z_IMAGE_LORA_QWEN3_PREFIX,
                     dtype=lora_dtype,
+                    cached_weights=cached_weights,
                 )
             )

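Note the second change threaded through this file: `model_on_device()` now also yields the RAM-cached weights (`cached_weights` instead of `_`), and they are handed to the LoRA patcher so it can restore original weights from the cache instead of cloning per-layer backups before patching. A sketch of the resulting call shape (the enclosing `apply_smart_model_patches` call is assumed from surrounding code and is not shown in this diff; only the `cached_weights=cached_weights` argument is new in this commit):

    # Hypothetical reconstruction of the full call site for context.
    exit_stack.enter_context(
        LayerPatcher.apply_smart_model_patches(
            model=text_encoder,
            patches=self._lora_iterator(context),
            prefix=Z_IMAGE_LORA_QWEN3_PREFIX,
            dtype=lora_dtype,
            cached_weights=cached_weights,  # original weights straight from the model cache
        )
    )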
@@ -14,6 +14,9 @@ import torch
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_manager.configs.factory import AnyModelConfig
 from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
+from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
+    CachedModelWithPartialLoad,
+)
 from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
@@ -80,6 +83,13 @@ class LoadedModelWithoutConfig:
         """Return the model without locking it."""
         return self._cache_record.cached_model.model

+    def repair_required_tensors_on_device(self) -> int:
+        """Repair required tensors that should be resident on the cached model's execution device."""
+        cached_model = self._cache_record.cached_model
+        if not isinstance(cached_model, CachedModelWithPartialLoad):
+            return 0
+        return cached_model.repair_required_tensors_on_compute_device()
+

 class LoadedModel(LoadedModelWithoutConfig):
     """Context manager object that mediates transfer from RAM<->VRAM."""
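The public entry point deliberately dispatches on the cache strategy: only a CachedModelWithPartialLoad can end up with required tensors stranded on the wrong device, so every other cached-model type reports zero repairs. Callers can therefore invoke it unconditionally, as the invocation hunks above do (a sketch of the caller-side contract, reusing names from those hunks):

    loaded_model = context.models.load(self.glm_encoder.text_encoder)
    repaired = loaded_model.repair_required_tensors_on_device()
    if repaired > 0:
        # Non-zero only for partially loaded models whose required tensors were moved.
        context.logger.warning(f"Recovered {repaired} required tensor(s)")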
@@ -149,6 +149,27 @@ class CachedModelWithPartialLoad:
         """Unload all weights from VRAM."""
         return self.partial_unload_from_vram(self.total_bytes())

+    @torch.no_grad()
+    def repair_required_tensors_on_compute_device(self) -> int:
+        """Repair required non-autocast tensors that were left off the compute device.
+
+        This can happen if an interrupted run leaves the model in a partially inconsistent state. Any repaired device
+        movement invalidates the cached VRAM accounting.
+        """
+        cur_state_dict = self._model.state_dict()
+        keys_to_repair = {
+            key
+            for key in self._keys_in_modules_that_do_not_support_autocast
+            if cur_state_dict[key].device.type != self._compute_device.type
+        }
+        if len(keys_to_repair) == 0:
+            return 0
+
+        self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_repair, self._compute_device)
+        self._move_non_persistent_buffers_to_device(self._compute_device)
+        self._cur_vram_bytes = None
+        return len(keys_to_repair)
+
     def _load_state_dict_with_device_conversion(
         self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device
     ):
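Two details of this implementation are worth calling out. First, the scan compares device types (`cuda` vs `cpu`), so a tensor on cuda:1 with a cuda:0 compute device is not treated as broken. Second, `self._cur_vram_bytes = None` discards the cached VRAM byte count, forcing the next accounting query to recompute from actual tensor placement, since the repair moves weights without going through the normal load path. A standalone sketch of the key scan (illustrative only; `required_keys` stands in for the private `_keys_in_modules_that_do_not_support_autocast` set):

    import torch

    def find_keys_to_repair(
        model: torch.nn.Module, required_keys: set[str], compute_device: torch.device
    ) -> set[str]:
        # Same comparison as the hunk above: device.type, not the full device,
        # so cross-GPU placement does not trigger a spurious repair.
        state_dict = model.state_dict()
        return {k for k in required_keys if state_dict[k].device.type != compute_device.type}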
@@ -4,7 +4,8 @@
     "uploadImage": "Lataa kuva",
     "invokeProgressBar": "Invoken edistymispalkki",
     "nextImage": "Seuraava kuva",
-    "previousImage": "Edellinen kuva"
+    "previousImage": "Edellinen kuva",
+    "uploadImages": "Lähetä Kuva(t)"
   },
   "common": {
     "languagePickerLabel": "Kielen valinta",
@@ -29,5 +30,28 @@
     "galleryImageSize": "Kuvan koko",
     "gallerySettings": "Gallerian asetukset",
     "autoSwitchNewImages": "Vaihda uusiin kuviin automaattisesti"
-  }
+  },
+  "modelManager": {
+    "t5Encoder": "T5-kooderi",
+    "qwen3Encoder": "Qwen3-kooderi",
+    "zImageVae": "VAE (valinnainen)",
+    "zImageQwen3Encoder": "Qwen3-kooderi (valinnainen)",
+    "zImageQwen3SourcePlaceholder": "Pakollinen, jos VAE/Enkooderi on tyhjä",
+    "flux2KleinVae": "VAE (valinnainen)",
+    "flux2KleinQwen3Encoder": "Qwen3-kooderi (valinnainen)"
+  },
+  "auth": {
+    "login": {
+      "title": "Kirjaudu sisään InvokeAI:hin",
+      "password": "Salasana",
+      "passwordPlaceholder": "Salasana",
+      "signIn": "Kirjaudu sisään",
+      "signingIn": "Kirjaudutaan sisään...",
+      "loginFailed": "Kirjautuminen epäonnistui. Tarkista käyttäjätunnuksesi tiedot."
+    },
+    "setup": {
+      "title": "Tervetuloa InvokeAI:hin",
+      "subtitle": "Määritä ensimmäiseksi järjestelmänvalvojan tili"
+    }
+  }
 }
@@ -3139,6 +3139,11 @@
       "back": "Indietro",
       "cannotDeleteSelf": "Non puoi eliminare il tuo account",
       "cannotDeactivateSelf": "Non puoi disattivare il tuo account"
-    }
+    },
+    "passwordStrength": {
+      "weak": "Password debole",
+      "moderate": "Password moderata",
+      "strong": "Password forte"
+    }
   }
 }
@@ -9,7 +9,8 @@ import { rasterLayerGlobalCompositeOperationChanged } from 'features/controlLaye
 import type { CanvasEntityIdentifier, CompositeOperation } from 'features/controlLayers/store/types';
 import { memo, useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
-import { CgPathBack, CgPathCrop, CgPathExclude, CgPathFront, CgPathIntersect } from 'react-icons/cg';
+import { CgPathBack, CgPathExclude, CgPathFront, CgPathIntersect } from 'react-icons/cg';
+import { PiIntersectSquareBold } from 'react-icons/pi';

 export const RasterLayerMenuItemsBooleanSubMenu = memo(() => {
   const { t } = useTranslation();
@@ -48,7 +49,7 @@ export const RasterLayerMenuItemsBooleanSubMenu = memo(() => {
   const disabled = isBusy || !entityIdentifierBelowThisOne;

   return (
-    <MenuItem {...subMenu.parentMenuItemProps} isDisabled={disabled} icon={<CgPathCrop size={18} />}>
+    <MenuItem {...subMenu.parentMenuItemProps} isDisabled={disabled} icon={<PiIntersectSquareBold />}>
       <Menu {...subMenu.menuProps}>
         <MenuButton {...subMenu.menuButtonProps}>
           <SubMenuButtonContent label={t('controlLayers.booleanOps.label')} />
tests/app/invocations/test_cogview4_text_encoder.py (new file, 80 lines)
@@ -0,0 +1,80 @@
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock

import torch

from invokeai.app.invocations.cogview4_text_encoder import CogView4TextEncoderInvocation


class FakeGlmModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("weight", torch.nn.Parameter(torch.ones(1)))
        self.repaired = False
        self.forward_input_device: torch.device | None = None

    def forward(self, input_ids: torch.Tensor, output_hidden_states: bool = False):
        assert output_hidden_states
        # Encode the ordering contract under test: forward() must not run until
        # repair_required_tensors_on_device() has been called.
        if not self.repaired:
            raise RuntimeError("model must be repaired before forward")

        self.forward_input_device = input_ids.device
        hidden = input_ids.unsqueeze(-1).float()
        return SimpleNamespace(hidden_states=[hidden, hidden + 1])


class FakeTokenizer:
    pad_token_id = 0

    def __call__(self, prompt, padding, max_length=None, truncation=None, add_special_tokens=None, return_tensors=None):
        del prompt, padding, max_length, truncation, add_special_tokens, return_tensors
        return SimpleNamespace(input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long))

    def batch_decode(self, input_ids):
        del input_ids
        return ["decoded"]


class FakeLoadedModel:
    def __init__(self, model):
        self._model = model
        self.repair_calls = 0

    @contextmanager
    def model_on_device(self):
        yield (None, self._model)

    def repair_required_tensors_on_device(self) -> int:
        self.repair_calls += 1
        self._model.repaired = True
        return 1


def test_cogview4_text_encoder_repairs_model_before_forward(monkeypatch):
    fake_model = FakeGlmModel()
    fake_tokenizer = FakeTokenizer()
    fake_model_info = FakeLoadedModel(fake_model)
    fake_tokenizer_info = FakeLoadedModel(fake_tokenizer)

    mock_context = MagicMock()
    mock_context.models.load.side_effect = [fake_model_info, fake_tokenizer_info]
    mock_context.util.signal_progress = MagicMock()
    mock_context.logger.warning = MagicMock()

    invocation = CogView4TextEncoderInvocation.model_construct(
        prompt="test prompt",
        glm_encoder=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace()),
    )

    module_path = "invokeai.app.invocations.cogview4_text_encoder"
    monkeypatch.setattr(f"{module_path}.GlmModel", FakeGlmModel)
    monkeypatch.setattr(f"{module_path}.PreTrainedTokenizerFast", FakeTokenizer)

    embeds = invocation._glm_encode(mock_context, max_seq_len=16)

    assert fake_model_info.repair_calls == 1
    mock_context.logger.warning.assert_called_once()
    mock_context.util.signal_progress.assert_called_once_with("Running GLM text encoder")
    assert fake_model.forward_input_device == torch.device("cpu")
    assert embeds.shape == (1, 16, 1)
@@ -0,0 +1,47 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
    CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
)


class ModelWithRequiredScale(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.scale = torch.nn.Parameter(torch.ones(4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) * self.scale


@pytest.mark.parametrize(
    "device",
    [
        pytest.param(
            torch.device("cuda"), marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
        ),
        pytest.param(
            torch.device("mps"),
            marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="requires MPS device"),
        ),
    ],
)
@pytest.mark.parametrize("keep_ram_copy", [True, False])
@torch.no_grad()
def test_repair_required_tensors_on_compute_device(device: torch.device, keep_ram_copy: bool):
    model = ModelWithRequiredScale()
    apply_custom_layers_to_model(model, device_autocasting_enabled=True)
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=device, keep_ram_copy=keep_ram_copy)

    # Seed stale VRAM accounting so the test can verify that a repair invalidates it.
    cached_model._cur_vram_bytes = 0
    repaired_tensors = cached_model.repair_required_tensors_on_compute_device()

    assert repaired_tensors == 1
    assert cached_model._cur_vram_bytes is None
    assert model.scale.device.type == device.type
    assert all(param.device.type == "cpu" for param in model.linear.parameters())