mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 09:18:00 -05:00
Compare commits
81 Commits
ryan/fix-d
...
ryan/model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2144d21f80 | ||
|
|
958efa19d7 | ||
|
|
11af57def3 | ||
|
|
8b70a5b9bd | ||
|
|
5d9fdcd78d | ||
|
|
c7b84cf012 | ||
|
|
8e409e3436 | ||
|
|
987393853c | ||
|
|
91c5af1b95 | ||
|
|
5c67dd507a | ||
|
|
2ff928ec17 | ||
|
|
4327bbe77e | ||
|
|
ad1c0d37ef | ||
|
|
9708d87946 | ||
|
|
3ad44f7850 | ||
|
|
9a482981b2 | ||
|
|
6b02362b12 | ||
|
|
8fec4ec91c | ||
|
|
693e421970 | ||
|
|
dc14104bc8 | ||
|
|
f286a1d1f3 | ||
|
|
9dc86b2b71 | ||
|
|
2cab689b79 | ||
|
|
f8c7adddd0 | ||
|
|
17da1d92e9 | ||
|
|
1cc57a4854 | ||
|
|
3993fae331 | ||
|
|
1446526d55 | ||
|
|
62c024e725 | ||
|
|
1e92bb4e94 | ||
|
|
db6398fdf6 | ||
|
|
ebd73a2ac2 | ||
|
|
8ee95cab00 | ||
|
|
d1184201a8 | ||
|
|
5887891654 | ||
|
|
765ca4e004 | ||
|
|
159b00a490 | ||
|
|
3fbf6f2d2a | ||
|
|
931fca7cd1 | ||
|
|
db84a3a5d4 | ||
|
|
ca8313e805 | ||
|
|
df849035ee | ||
|
|
8d97fe69ca | ||
|
|
9044e53a9b | ||
|
|
6012b0f912 | ||
|
|
bb0ed5dc8a | ||
|
|
021552fd81 | ||
|
|
be73dbba92 | ||
|
|
db9c0cad7c | ||
|
|
54b7f9a063 | ||
|
|
7d488a5352 | ||
|
|
4d7667f63d | ||
|
|
08704ee8ec | ||
|
|
5910892c33 | ||
|
|
46a09d9e90 | ||
|
|
df0c7d73f3 | ||
|
|
3905c97e32 | ||
|
|
0be796a808 | ||
|
|
7dd33b0f39 | ||
|
|
484aaf1595 | ||
|
|
c276b60af9 | ||
|
|
5d8dd6e26e | ||
|
|
5bca68d873 | ||
|
|
64364e7911 | ||
|
|
6565cea039 | ||
|
|
3ebd8d6c07 | ||
|
|
e970185161 | ||
|
|
fa5653cdf7 | ||
|
|
9a7b000995 | ||
|
|
3a27242838 | ||
|
|
b54463d294 | ||
|
|
faee79dc95 | ||
|
|
e01f66b026 | ||
|
|
53abdde242 | ||
|
|
94c088300f | ||
|
|
3741a6f5e0 | ||
|
|
2c23b8414c | ||
|
|
20356c0746 | ||
|
|
bad1149504 | ||
|
|
fda7aaa7ca | ||
|
|
85c616fa34 |
@@ -1364,7 +1364,6 @@ the in-memory loaded model:
|
||||
|----------------|-----------------|------------------|
|
||||
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
|
||||
| `model` | AnyModel | The instantiated model (details below) |
|
||||
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
|
||||
|
||||
### get_model_by_key(key, [submodel]) -> LoadedModel
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ This project is a combined effort of dedicated people from across the world. [C
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
|
||||
The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/docs/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
|
||||
|
||||
By making a contribution to this project, you certify that:
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
@@ -110,7 +110,7 @@ async def cancel_by_batch_ids(
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/cancel_by_destination",
|
||||
operation_id="cancel_by_destination",
|
||||
responses={200: {"model": CancelByBatchIDsResult}},
|
||||
responses={200: {"model": CancelByDestinationResult}},
|
||||
)
|
||||
async def cancel_by_destination(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
|
||||
@@ -250,6 +250,11 @@ class FluxConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
mask: Optional[TensorField] = Field(
|
||||
default=None,
|
||||
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
"included regions should be set to True.",
|
||||
)
|
||||
|
||||
|
||||
class SD3ConditioningField(BaseModel):
|
||||
|
||||
@@ -30,6 +30,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
|
||||
@@ -42,6 +43,7 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
@@ -56,7 +58,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="3.2.1",
|
||||
version="3.2.2",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -87,10 +89,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_text_conditioning: FluxConditioningField | None = InputField(
|
||||
negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
|
||||
default=None,
|
||||
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
|
||||
input=Input.Connection,
|
||||
@@ -139,36 +141,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _load_text_conditioning(
|
||||
self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=dtype)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
return t5_embeddings, clip_embeddings
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Load the conditioning data.
|
||||
pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
|
||||
context, self.positive_text_conditioning.conditioning_name, inference_dtype
|
||||
)
|
||||
neg_t5_embeddings: torch.Tensor | None = None
|
||||
neg_clip_embeddings: torch.Tensor | None = None
|
||||
if self.negative_text_conditioning is not None:
|
||||
neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
|
||||
context, self.negative_text_conditioning.conditioning_name, inference_dtype
|
||||
)
|
||||
|
||||
# Load the input latents, if provided.
|
||||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
|
||||
if init_latents is not None:
|
||||
@@ -183,15 +161,45 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
b, _c, latent_h, latent_w = noise.shape
|
||||
packed_h = latent_h // 2
|
||||
packed_w = latent_w // 2
|
||||
|
||||
# Load the conditioning data.
|
||||
pos_text_conditionings = self._load_text_conditioning(
|
||||
context=context,
|
||||
cond_field=self.positive_text_conditioning,
|
||||
packed_height=packed_h,
|
||||
packed_width=packed_w,
|
||||
dtype=inference_dtype,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
neg_text_conditionings: list[FluxTextConditioning] | None = None
|
||||
if self.negative_text_conditioning is not None:
|
||||
neg_text_conditionings = self._load_text_conditioning(
|
||||
context=context,
|
||||
cond_field=self.negative_text_conditioning,
|
||||
packed_height=packed_h,
|
||||
packed_width=packed_w,
|
||||
dtype=inference_dtype,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
|
||||
pos_text_conditionings, img_seq_len=packed_h * packed_w
|
||||
)
|
||||
neg_regional_prompting_extension = (
|
||||
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
|
||||
if neg_text_conditionings
|
||||
else None
|
||||
)
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=image_seq_len,
|
||||
image_seq_len=packed_h * packed_w,
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
@@ -228,28 +236,17 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
b, _c, latent_h, latent_w = x.shape
|
||||
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
|
||||
|
||||
pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
|
||||
pos_txt_ids = torch.zeros(
|
||||
pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
neg_txt_ids: torch.Tensor | None = None
|
||||
if neg_t5_embeddings is not None:
|
||||
neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
|
||||
neg_txt_ids = torch.zeros(
|
||||
neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
|
||||
assert image_seq_len == x.shape[1]
|
||||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and
|
||||
# packed_w correctly.
|
||||
assert packed_h * packed_w == x.shape[1]
|
||||
|
||||
# Prepare inpaint extension.
|
||||
inpaint_extension: InpaintExtension | None = None
|
||||
@@ -338,12 +335,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
model=transformer,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=pos_t5_embeddings,
|
||||
txt_ids=pos_txt_ids,
|
||||
vec=pos_clip_embeddings,
|
||||
neg_txt=neg_t5_embeddings,
|
||||
neg_txt_ids=neg_txt_ids,
|
||||
neg_vec=neg_clip_embeddings,
|
||||
pos_regional_prompting_extension=pos_regional_prompting_extension,
|
||||
neg_regional_prompting_extension=neg_regional_prompting_extension,
|
||||
timesteps=timesteps,
|
||||
step_callback=self._build_step_callback(context),
|
||||
guidance=self.guidance,
|
||||
@@ -357,6 +350,43 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
return x
|
||||
|
||||
def _load_text_conditioning(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
cond_field: FluxConditioningField | list[FluxConditioningField],
|
||||
packed_height: int,
|
||||
packed_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> list[FluxTextConditioning]:
|
||||
"""Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
|
||||
# Normalize to a list of FluxConditioningFields.
|
||||
cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
|
||||
|
||||
text_conditionings: list[FluxTextConditioning] = []
|
||||
for cond_field in cond_list:
|
||||
# Load the text embeddings.
|
||||
cond_data = context.conditioning.load(cond_field.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
# Load the mask, if provided.
|
||||
mask: Optional[torch.Tensor] = None
|
||||
if cond_field.mask is not None:
|
||||
mask = context.tensors.load(cond_field.mask.tensor_name)
|
||||
mask = mask.to(device=device)
|
||||
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
|
||||
mask, packed_height, packed_width, dtype, device
|
||||
)
|
||||
|
||||
text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
|
||||
|
||||
return text_conditionings
|
||||
|
||||
@classmethod
|
||||
def prep_cfg_scale(
|
||||
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Literal, Tuple
|
||||
from typing import Iterator, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@@ -22,7 +29,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
|
||||
title="FLUX Text Encoding",
|
||||
tags=["prompt", "conditioning", "flux"],
|
||||
category="conditioning",
|
||||
version="1.1.0",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
@@ -41,9 +48,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
t5_max_seq_len: Literal[256, 512] = InputField(
|
||||
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
|
||||
)
|
||||
prompt: str = InputField(
|
||||
description="Text prompt to encode.",
|
||||
ui_component=UIComponent.Textarea,
|
||||
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -57,7 +64,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return FluxConditioningOutput.build(conditioning_name)
|
||||
return FluxConditioningOutput(
|
||||
conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
|
||||
)
|
||||
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
|
||||
@@ -20,7 +20,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_common import (
|
||||
NodeExecutionStatsSummary,
|
||||
)
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load.model_cache import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
|
||||
# Size of 1GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Callable, Optional
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
|
||||
|
||||
class ModelLoadServiceBase(ABC):
|
||||
@@ -24,7 +24,7 @@ class ModelLoadServiceBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load import (
|
||||
ModelLoaderRegistry,
|
||||
ModelLoaderRegistryBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@@ -30,7 +30,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
|
||||
):
|
||||
"""Initialize the model load service."""
|
||||
@@ -45,7 +45,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._invoker = invoker
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@@ -78,9 +78,8 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
cache_key = str(model_path)
|
||||
ram_cache = self.ram_cache
|
||||
try:
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
@@ -109,5 +108,5 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
)
|
||||
assert loader is not None
|
||||
raw_model = loader(model_path)
|
||||
ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
self._ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
|
||||
|
||||
@@ -16,7 +16,8 @@ from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBas
|
||||
from invokeai.app.services.model_load.model_load_default import ModelLoadService
|
||||
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.math import attention
|
||||
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
|
||||
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock
|
||||
|
||||
|
||||
class CustomDoubleStreamBlockProcessor:
|
||||
@@ -13,7 +14,12 @@ class CustomDoubleStreamBlockProcessor:
|
||||
|
||||
@staticmethod
|
||||
def _double_stream_block_forward(
|
||||
block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
|
||||
block: DoubleStreamBlock,
|
||||
img: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
|
||||
values.
|
||||
@@ -40,7 +46,7 @@ class CustomDoubleStreamBlockProcessor:
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
@@ -63,11 +69,15 @@ class CustomDoubleStreamBlockProcessor:
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
regional_prompting_extension: RegionalPromptingExtension,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A custom implementation of DoubleStreamBlock.forward() with additional features:
|
||||
- IP-Adapter support
|
||||
"""
|
||||
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
|
||||
attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
|
||||
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
|
||||
block, img, txt, vec, pe, attn_mask=attn_mask
|
||||
)
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
for ip_adapter_extension in ip_adapter_extensions:
|
||||
@@ -81,3 +91,48 @@ class CustomDoubleStreamBlockProcessor:
|
||||
)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class CustomSingleStreamBlockProcessor:
|
||||
"""A class containing a custom implementation of SingleStreamBlock.forward() with additional features (masking,
|
||||
etc.)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _single_stream_block_forward(
|
||||
block: SingleStreamBlock,
|
||||
x: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""This function is a direct copy of SingleStreamBlock.forward()."""
|
||||
mod, _ = block.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
|
||||
q, k = block.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
|
||||
@staticmethod
|
||||
def custom_single_block_forward(
|
||||
timestep_index: int,
|
||||
total_num_timesteps: int,
|
||||
block_index: int,
|
||||
block: SingleStreamBlock,
|
||||
img: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
regional_prompting_extension: RegionalPromptingExtension,
|
||||
) -> torch.Tensor:
|
||||
"""A custom implementation of SingleStreamBlock.forward() with additional features:
|
||||
- Masking
|
||||
"""
|
||||
attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
|
||||
return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)
|
||||
|
||||
@@ -7,6 +7,7 @@ from tqdm import tqdm
|
||||
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
|
||||
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
@@ -18,14 +19,8 @@ def denoise(
|
||||
# model input
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
# positive text conditioning
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
# negative text conditioning
|
||||
neg_txt: torch.Tensor | None,
|
||||
neg_txt_ids: torch.Tensor | None,
|
||||
neg_vec: torch.Tensor | None,
|
||||
pos_regional_prompting_extension: RegionalPromptingExtension,
|
||||
neg_regional_prompting_extension: RegionalPromptingExtension | None,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[PipelineIntermediateState], None],
|
||||
@@ -61,9 +56,9 @@ def denoise(
|
||||
total_num_timesteps=total_steps,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
@@ -78,9 +73,9 @@ def denoise(
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
timestep_index=step_index,
|
||||
@@ -88,6 +83,7 @@ def denoise(
|
||||
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
|
||||
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
|
||||
ip_adapter_extensions=pos_ip_adapter_extensions,
|
||||
regional_prompting_extension=pos_regional_prompting_extension,
|
||||
)
|
||||
|
||||
step_cfg_scale = cfg_scale[step_index]
|
||||
@@ -97,15 +93,15 @@ def denoise(
|
||||
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
|
||||
# on systems with sufficient VRAM.
|
||||
|
||||
if neg_txt is None or neg_txt_ids is None or neg_vec is None:
|
||||
if neg_regional_prompting_extension is None:
|
||||
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
|
||||
|
||||
neg_pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=neg_txt,
|
||||
txt_ids=neg_txt_ids,
|
||||
y=neg_vec,
|
||||
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
timestep_index=step_index,
|
||||
@@ -113,6 +109,7 @@ def denoise(
|
||||
controlnet_double_block_residuals=None,
|
||||
controlnet_single_block_residuals=None,
|
||||
ip_adapter_extensions=neg_ip_adapter_extensions,
|
||||
regional_prompting_extension=neg_regional_prompting_extension,
|
||||
)
|
||||
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
|
||||
|
||||
|
||||
276
invokeai/backend/flux/extensions/regional_prompting_extension.py
Normal file
276
invokeai/backend/flux/extensions/regional_prompting_extension.py
Normal file
@@ -0,0 +1,276 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.mask import to_standard_float_mask
|
||||
|
||||
|
||||
class RegionalPromptingExtension:
|
||||
"""A class for managing regional prompting with FLUX.
|
||||
|
||||
This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
regional_text_conditioning: FluxRegionalTextConditioning,
|
||||
restricted_attn_mask: torch.Tensor | None = None,
|
||||
):
|
||||
self.regional_text_conditioning = regional_text_conditioning
|
||||
self.restricted_attn_mask = restricted_attn_mask
|
||||
|
||||
def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
|
||||
order = [self.restricted_attn_mask, None]
|
||||
return order[block_index % len(order)]
|
||||
|
||||
def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
|
||||
order = [self.restricted_attn_mask, None]
|
||||
return order[block_index % len(order)]
|
||||
|
||||
@classmethod
|
||||
def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
|
||||
"""Create a RegionalPromptingExtension from a list of text conditionings.
|
||||
|
||||
Args:
|
||||
text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting.
|
||||
img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
|
||||
"""
|
||||
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
|
||||
attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
|
||||
regional_text_conditioning, img_seq_len
|
||||
)
|
||||
return cls(
|
||||
regional_text_conditioning=regional_text_conditioning,
|
||||
restricted_attn_mask=attn_mask_with_restricted_img_self_attn,
|
||||
)
|
||||
|
||||
# Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy:
|
||||
#
|
||||
# @classmethod
|
||||
# def _prepare_unrestricted_attn_mask(
|
||||
# cls,
|
||||
# regional_text_conditioning: FluxRegionalTextConditioning,
|
||||
# img_seq_len: int,
|
||||
# ) -> torch.Tensor:
|
||||
# """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
|
||||
# - img self-attention is not masked.
|
||||
# - img regions attend to both txt within their own region and to global prompts.
|
||||
# """
|
||||
# device = TorchDevice.choose_torch_device()
|
||||
|
||||
# # Infer txt_seq_len from the t5_embeddings tensor.
|
||||
# txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
|
||||
|
||||
# # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
|
||||
# # Concatenation happens in the following order: [txt_seq, img_seq].
|
||||
# # There are 4 portions of the attention mask to consider as we prepare it:
|
||||
# # 1. txt attends to itself
|
||||
# # 2. txt attends to corresponding regional img
|
||||
# # 3. regional img attends to corresponding txt
|
||||
# # 4. regional img attends to itself
|
||||
|
||||
# # Initialize empty attention mask.
|
||||
# regional_attention_mask = torch.zeros(
|
||||
# (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
|
||||
# )
|
||||
|
||||
# for image_mask, t5_embedding_range in zip(
|
||||
# regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
|
||||
# ):
|
||||
# # 1. txt attends to itself
|
||||
# regional_attention_mask[
|
||||
# t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
|
||||
# ] = 1.0
|
||||
|
||||
# # 2. txt attends to corresponding regional img
|
||||
# # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
# fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
|
||||
# regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value
|
||||
|
||||
# # 3. regional img attends to corresponding txt
|
||||
# # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
# fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
|
||||
# regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value
|
||||
|
||||
# # 4. regional img attends to itself
|
||||
# # Allow unrestricted img self attention.
|
||||
# regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
|
||||
|
||||
# # Convert attention mask to boolean.
|
||||
# regional_attention_mask = regional_attention_mask > 0.5
|
||||
|
||||
# return regional_attention_mask
|
||||
|
||||
@classmethod
|
||||
def _prepare_restricted_attn_mask(
|
||||
cls,
|
||||
regional_text_conditioning: FluxRegionalTextConditioning,
|
||||
img_seq_len: int,
|
||||
) -> torch.Tensor | None:
|
||||
"""Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
|
||||
- img self-attention is only allowed within regions.
|
||||
- img regions only attend to txt within their own region, not to global prompts.
|
||||
"""
|
||||
# Identify background region. I.e. the region that is not covered by any region masks.
|
||||
background_region_mask: None | torch.Tensor = None
|
||||
for image_mask in regional_text_conditioning.image_masks:
|
||||
if image_mask is not None:
|
||||
if background_region_mask is None:
|
||||
background_region_mask = torch.ones_like(image_mask)
|
||||
background_region_mask *= 1 - image_mask
|
||||
|
||||
if background_region_mask is None:
|
||||
# There are no region masks, short-circuit and return None.
|
||||
# TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this would
|
||||
# is a rare use case and would make the logic here significantly more complicated.
|
||||
return None
|
||||
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
# Infer txt_seq_len from the t5_embeddings tensor.
|
||||
txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
|
||||
|
||||
# In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
|
||||
# Concatenation happens in the following order: [txt_seq, img_seq].
|
||||
# There are 4 portions of the attention mask to consider as we prepare it:
|
||||
# 1. txt attends to itself
|
||||
# 2. txt attends to corresponding regional img
|
||||
# 3. regional img attends to corresponding txt
|
||||
# 4. regional img attends to itself
|
||||
|
||||
# Initialize empty attention mask.
|
||||
regional_attention_mask = torch.zeros(
|
||||
(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
|
||||
)
|
||||
|
||||
for image_mask, t5_embedding_range in zip(
|
||||
regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
|
||||
):
|
||||
# 1. txt attends to itself
|
||||
regional_attention_mask[
|
||||
t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
|
||||
] = 1.0
|
||||
|
||||
if image_mask is not None:
|
||||
# 2. txt attends to corresponding regional img
|
||||
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
|
||||
image_mask.view(1, img_seq_len)
|
||||
)
|
||||
|
||||
# 3. regional img attends to corresponding txt
|
||||
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
|
||||
image_mask.view(img_seq_len, 1)
|
||||
)
|
||||
|
||||
# 4. regional img attends to itself
|
||||
image_mask = image_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
|
||||
else:
|
||||
# We don't allow attention between non-background image regions and global prompts. This helps to ensure
|
||||
# that regions focus on their local prompts. We do, however, allow attention between background regions
|
||||
# and global prompts. If we didn't do this, then the background regions would not attend to any txt
|
||||
# embeddings, which we found experimentally to cause artifacts.
|
||||
|
||||
# 2. global txt attends to background region
|
||||
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
|
||||
background_region_mask.view(1, img_seq_len)
|
||||
)
|
||||
|
||||
# 3. background region attends to global txt
|
||||
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
|
||||
background_region_mask.view(img_seq_len, 1)
|
||||
)
|
||||
|
||||
# Allow background regions to attend to themselves.
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
|
||||
|
||||
# Convert attention mask to boolean.
|
||||
regional_attention_mask = regional_attention_mask > 0.5
|
||||
|
||||
return regional_attention_mask
|
||||
|
||||
@classmethod
|
||||
def _concat_regional_text_conditioning(
|
||||
cls,
|
||||
text_conditionings: list[FluxTextConditioning],
|
||||
) -> FluxRegionalTextConditioning:
|
||||
"""Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
|
||||
concat_t5_embeddings: list[torch.Tensor] = []
|
||||
concat_t5_embedding_ranges: list[Range] = []
|
||||
image_masks: list[torch.Tensor | None] = []
|
||||
|
||||
# Choose global CLIP embedding.
|
||||
# Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use
|
||||
# the first prompt's CLIP embedding.
|
||||
global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings
|
||||
for text_conditioning in text_conditionings:
|
||||
if text_conditioning.mask is None:
|
||||
global_clip_embedding = text_conditioning.clip_embeddings
|
||||
break
|
||||
|
||||
cur_t5_embedding_len = 0
|
||||
for text_conditioning in text_conditionings:
|
||||
concat_t5_embeddings.append(text_conditioning.t5_embeddings)
|
||||
|
||||
concat_t5_embedding_ranges.append(
|
||||
Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
|
||||
)
|
||||
|
||||
image_masks.append(text_conditioning.mask)
|
||||
|
||||
cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
|
||||
|
||||
t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
|
||||
|
||||
# Initialize the txt_ids tensor.
|
||||
pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
|
||||
t5_txt_ids = torch.zeros(
|
||||
pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
return FluxRegionalTextConditioning(
|
||||
t5_embeddings=t5_embeddings,
|
||||
clip_embeddings=global_clip_embedding,
|
||||
t5_txt_ids=t5_txt_ids,
|
||||
image_masks=image_masks,
|
||||
t5_embedding_ranges=concat_t5_embedding_ranges,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def preprocess_regional_prompt_mask(
|
||||
mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
|
||||
|
||||
packed_height and packed_width are the target height and width of the mask in the 'packed' latent space.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The processed mask. shape: (1, 1, packed_height * packed_width).
|
||||
"""
|
||||
|
||||
if mask is None:
|
||||
return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device)
|
||||
|
||||
mask = to_standard_float_mask(mask, out_dtype=dtype)
|
||||
|
||||
tf = torchvision.transforms.Resize(
|
||||
(packed_height, packed_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
|
||||
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||
resized_mask = tf(mask)
|
||||
|
||||
# Flatten the height and width dimensions into a single image_seq_len dimension.
|
||||
return resized_mask.flatten(start_dim=2)
|
||||
@@ -5,10 +5,10 @@ from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
|
||||
return x
|
||||
@@ -24,12 +24,12 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.float()
|
||||
return out.to(dtype=pos.dtype, device=pos.device)
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_ = xq.view(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.view(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
return xq_out.view(*xq.shape), xk_out.view(*xk.shape)
|
||||
|
||||
@@ -5,7 +5,11 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
|
||||
from invokeai.backend.flux.custom_block_processor import (
|
||||
CustomDoubleStreamBlockProcessor,
|
||||
CustomSingleStreamBlockProcessor,
|
||||
)
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.modules.layers import (
|
||||
DoubleStreamBlock,
|
||||
@@ -95,6 +99,7 @@ class Flux(nn.Module):
|
||||
controlnet_double_block_residuals: list[Tensor] | None,
|
||||
controlnet_single_block_residuals: list[Tensor] | None,
|
||||
ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
regional_prompting_extension: RegionalPromptingExtension,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -117,7 +122,6 @@ class Flux(nn.Module):
|
||||
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
|
||||
for block_index, block in enumerate(self.double_blocks):
|
||||
assert isinstance(block, DoubleStreamBlock)
|
||||
|
||||
img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
|
||||
timestep_index=timestep_index,
|
||||
total_num_timesteps=total_num_timesteps,
|
||||
@@ -128,6 +132,7 @@ class Flux(nn.Module):
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
ip_adapter_extensions=ip_adapter_extensions,
|
||||
regional_prompting_extension=regional_prompting_extension,
|
||||
)
|
||||
|
||||
if controlnet_double_block_residuals is not None:
|
||||
@@ -140,7 +145,17 @@ class Flux(nn.Module):
|
||||
assert len(controlnet_single_block_residuals) == len(self.single_blocks)
|
||||
|
||||
for block_index, block in enumerate(self.single_blocks):
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
assert isinstance(block, SingleStreamBlock)
|
||||
img = CustomSingleStreamBlockProcessor.custom_single_block_forward(
|
||||
timestep_index=timestep_index,
|
||||
total_num_timesteps=total_num_timesteps,
|
||||
block_index=block_index,
|
||||
block=block,
|
||||
img=img,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
regional_prompting_extension=regional_prompting_extension,
|
||||
)
|
||||
|
||||
if controlnet_single_block_residuals is not None:
|
||||
img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]
|
||||
|
||||
@@ -66,10 +66,7 @@ class RMSNorm(torch.nn.Module):
|
||||
self.scale = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||
return torch.nn.functional.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
|
||||
36
invokeai/backend/flux/text_conditioning.py
Normal file
36
invokeai/backend/flux/text_conditioning.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxTextConditioning:
|
||||
t5_embeddings: torch.Tensor
|
||||
clip_embeddings: torch.Tensor
|
||||
# If mask is None, the prompt is a global prompt.
|
||||
mask: torch.Tensor | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxRegionalTextConditioning:
|
||||
# Concatenated text embeddings.
|
||||
# Shape: (1, concatenated_txt_seq_len, 4096)
|
||||
t5_embeddings: torch.Tensor
|
||||
# Shape: (1, concatenated_txt_seq_len, 3)
|
||||
t5_txt_ids: torch.Tensor
|
||||
|
||||
# Global CLIP embeddings.
|
||||
# Shape: (1, 768)
|
||||
clip_embeddings: torch.Tensor
|
||||
|
||||
# A binary mask indicating the regions of the image that the prompt should be applied to. If None, the prompt is a
|
||||
# global prompt.
|
||||
# image_masks[i] is the mask for the ith prompt.
|
||||
# image_masks[i] has shape (1, image_seq_len) and dtype torch.bool.
|
||||
image_masks: list[torch.Tensor | None]
|
||||
|
||||
# List of ranges that represent the embedding ranges for each mask.
|
||||
# t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
|
||||
t5_embedding_ranges: list[Range]
|
||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
||||
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
|
||||
# This registers the subclasses that implement loaders of specific model types
|
||||
|
||||
@@ -5,7 +5,6 @@ Base class for model loading in InvokeAI.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, Optional, Tuple
|
||||
@@ -18,19 +17,17 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModelWithoutConfig:
|
||||
"""
|
||||
Context manager object that mediates transfer from RAM<->VRAM.
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM.
|
||||
|
||||
This is a context manager object that has two distinct APIs:
|
||||
|
||||
1. Older API (deprecated):
|
||||
Use the LoadedModel object directly as a context manager.
|
||||
It will move the model into VRAM (on CUDA devices), and
|
||||
Use the LoadedModel object directly as a context manager. It will move the model into VRAM (on CUDA devices), and
|
||||
return the model in a form suitable for passing to torch.
|
||||
Example:
|
||||
```
|
||||
@@ -40,13 +37,9 @@ class LoadedModelWithoutConfig:
|
||||
```
|
||||
|
||||
2. Newer API (recommended):
|
||||
Call the LoadedModel's `model_on_device()` method in a
|
||||
context. It returns a tuple consisting of a copy of
|
||||
the model's state dict in CPU RAM followed by a copy
|
||||
of the model in VRAM. The state dict is provided to allow
|
||||
LoRAs and other model patchers to return the model to
|
||||
its unpatched state without expensive copy and restore
|
||||
operations.
|
||||
Call the LoadedModel's `model_on_device()` method in a context. It returns a tuple consisting of a copy of the
|
||||
model's state dict in CPU RAM followed by a copy of the model in VRAM. The state dict is provided to allow LoRAs and
|
||||
other model patchers to return the model to its unpatched state without expensive copy and restore operations.
|
||||
|
||||
Example:
|
||||
```
|
||||
@@ -55,43 +48,42 @@ class LoadedModelWithoutConfig:
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
The state_dict should be treated as a read-only object and
|
||||
never modified. Also be aware that some loadable models do
|
||||
not have a state_dict, in which case this value will be None.
|
||||
The state_dict should be treated as a read-only object and never modified. Also be aware that some loadable models
|
||||
do not have a state_dict, in which case this value will be None.
|
||||
"""
|
||||
|
||||
_locker: ModelLockerBase
|
||||
def __init__(self, cache_record: CacheRecord, cache: ModelCache):
|
||||
self._cache_record = cache_record
|
||||
self._cache = cache
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
"""Context entry."""
|
||||
self._locker.lock()
|
||||
self._cache.lock(self._cache_record.key)
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Context exit."""
|
||||
self._locker.unlock()
|
||||
self._cache.unlock(self._cache_record.key)
|
||||
|
||||
@contextmanager
|
||||
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
|
||||
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
|
||||
locked_model = self._locker.lock()
|
||||
self._cache.lock(self._cache_record.key)
|
||||
try:
|
||||
state_dict = self._locker.get_state_dict()
|
||||
yield (state_dict, locked_model)
|
||||
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
|
||||
finally:
|
||||
self._locker.unlock()
|
||||
self._cache.unlock(self._cache_record.key)
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without locking it."""
|
||||
return self._locker.model
|
||||
return self._cache_record.cached_model.model
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel(LoadedModelWithoutConfig):
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
config: Optional[AnyModelConfig] = None
|
||||
def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, cache: ModelCache):
|
||||
super().__init__(cache_record=cache_record, cache=cache)
|
||||
self.config = config
|
||||
|
||||
|
||||
# TODO(MM2):
|
||||
@@ -110,7 +102,7 @@ class ModelLoaderBase(ABC):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
pass
|
||||
@@ -138,6 +130,6 @@ class ModelLoaderBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the ram cache associated with this loader."""
|
||||
pass
|
||||
|
||||
@@ -14,7 +14,8 @@ from invokeai.backend.model_manager import (
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@@ -28,7 +29,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
self._app_config = app_config
|
||||
@@ -54,11 +55,11 @@ class ModelLoader(ModelLoaderBase):
|
||||
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
|
||||
|
||||
with skip_torch_weight_init():
|
||||
locker = self._load_and_cache(model_config, submodel_type)
|
||||
return LoadedModel(config=model_config, _locker=locker)
|
||||
cache_record = self._load_and_cache(model_config, submodel_type)
|
||||
return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the ram cache associated with this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@@ -66,10 +67,10 @@ class ModelLoader(ModelLoaderBase):
|
||||
model_base = self._app_config.models_path
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
|
||||
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
|
||||
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
@@ -78,16 +79,11 @@ class ModelLoader(ModelLoaderBase):
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
config.key,
|
||||
submodel_type=submodel_type,
|
||||
get_model_cache_key(config.key, submodel_type),
|
||||
model=loaded_model,
|
||||
)
|
||||
|
||||
return self._ram_cache.get(
|
||||
key=config.key,
|
||||
submodel_type=submodel_type,
|
||||
stats_name=stats_name,
|
||||
)
|
||||
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
|
||||
|
||||
def get_size_fs(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Init file for ModelCache."""
|
||||
|
||||
from .model_cache_base import ModelCacheBase, CacheStats # noqa F401
|
||||
from .model_cache_default import ModelCache # noqa F401
|
||||
|
||||
_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheRecord:
|
||||
"""A class that represents a model in the model cache."""
|
||||
|
||||
# Cache key.
|
||||
key: str
|
||||
# Model in memory.
|
||||
cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
|
||||
# If locks > 0, the model is actively being used, so we should do our best to keep it on the compute device.
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self) -> None:
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def is_locked(self) -> bool:
|
||||
return self._locks > 0
|
||||
@@ -0,0 +1,15 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
"""Collect statistics on cache performance."""
|
||||
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
@@ -0,0 +1,81 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class CachedModelOnlyFullLoad:
|
||||
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
|
||||
|
||||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
|
||||
MPS memory, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
|
||||
"""Initialize a CachedModelOnlyFullLoad.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
|
||||
compute_device (torch.device): The compute device to move the model to.
|
||||
total_bytes (int): The total size (in bytes) of all the weights in the model.
|
||||
"""
|
||||
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
|
||||
self._model = model
|
||||
self._compute_device = compute_device
|
||||
self._total_bytes = total_bytes
|
||||
self._is_in_vram = False
|
||||
|
||||
@property
|
||||
def model(self) -> torch.nn.Module:
|
||||
return self._model
|
||||
|
||||
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
|
||||
"""Get a read-only copy of the model's state dict in RAM."""
|
||||
# TODO(ryand): Document this better and implement it.
|
||||
return None
|
||||
|
||||
def total_bytes(self) -> int:
|
||||
"""Get the total size (in bytes) of all the weights in the model."""
|
||||
return self._total_bytes
|
||||
|
||||
def cur_vram_bytes(self) -> int:
|
||||
"""Get the size (in bytes) of the weights that are currently in VRAM."""
|
||||
if self._is_in_vram:
|
||||
return self._total_bytes
|
||||
else:
|
||||
return 0
|
||||
|
||||
def is_in_vram(self) -> bool:
|
||||
"""Return true if the model is currently in VRAM."""
|
||||
return self._is_in_vram
|
||||
|
||||
def full_load_to_vram(self) -> int:
|
||||
"""Load all weights into VRAM (if supported by the model).
|
||||
|
||||
Returns:
|
||||
The number of bytes loaded into VRAM.
|
||||
"""
|
||||
if self._is_in_vram:
|
||||
# Already in VRAM.
|
||||
return 0
|
||||
|
||||
if not hasattr(self._model, "to"):
|
||||
# Model doesn't support moving to a device.
|
||||
return 0
|
||||
|
||||
self._model.to(self._compute_device)
|
||||
self._is_in_vram = True
|
||||
return self._total_bytes
|
||||
|
||||
def full_unload_from_vram(self) -> int:
|
||||
"""Unload all weights from VRAM.
|
||||
|
||||
Returns:
|
||||
The number of bytes unloaded from VRAM.
|
||||
"""
|
||||
if not self._is_in_vram:
|
||||
# Already in RAM.
|
||||
return 0
|
||||
|
||||
self._model.to("cpu")
|
||||
self._is_in_vram = False
|
||||
return self._total_bytes
|
||||
@@ -0,0 +1,139 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
|
||||
add_autocast_to_module_forward,
|
||||
)
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
|
||||
def set_nested_attr(obj: object, attr: str, value: object):
|
||||
"""A helper function that extends setattr() to support nested attributes.
|
||||
|
||||
Example:
|
||||
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
|
||||
"""
|
||||
attrs = attr.split(".")
|
||||
for attr in attrs[:-1]:
|
||||
obj = getattr(obj, attr)
|
||||
setattr(obj, attrs[-1], value)
|
||||
|
||||
|
||||
class CachedModelWithPartialLoad:
|
||||
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
|
||||
|
||||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
|
||||
MPS memory, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
|
||||
self._model = model
|
||||
self._compute_device = compute_device
|
||||
|
||||
# A CPU read-only copy of the model's state dict.
|
||||
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
|
||||
|
||||
# Monkey-patch the model to add autocasting to the model's forward method.
|
||||
add_autocast_to_module_forward(model, compute_device)
|
||||
|
||||
# TODO(ryand): Manage a read-only CPU copy of the model state dict.
|
||||
# TODO(ryand): Add memoization for total_bytes and cur_vram_bytes?
|
||||
|
||||
self._total_bytes = sum(calc_tensor_size(p) for p in self._model.parameters())
|
||||
self._cur_vram_bytes: int | None = None
|
||||
|
||||
@property
|
||||
def model(self) -> torch.nn.Module:
|
||||
return self._model
|
||||
|
||||
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
|
||||
"""Get a read-only copy of the model's state dict in RAM."""
|
||||
# TODO(ryand): Document this better.
|
||||
return self._cpu_state_dict
|
||||
|
||||
def total_bytes(self) -> int:
|
||||
"""Get the total size (in bytes) of all the weights in the model."""
|
||||
return self._total_bytes
|
||||
|
||||
def cur_vram_bytes(self) -> int:
|
||||
"""Get the size (in bytes) of the weights that are currently in VRAM."""
|
||||
if self._cur_vram_bytes is None:
|
||||
self._cur_vram_bytes = sum(
|
||||
calc_tensor_size(p) for p in self._model.parameters() if p.device.type == self._compute_device.type
|
||||
)
|
||||
return self._cur_vram_bytes
|
||||
|
||||
def full_load_to_vram(self) -> int:
|
||||
"""Load all weights into VRAM."""
|
||||
return self.partial_load_to_vram(self.total_bytes())
|
||||
|
||||
def full_unload_from_vram(self) -> int:
|
||||
"""Unload all weights from VRAM."""
|
||||
return self.partial_unload_from_vram(self.total_bytes())
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
|
||||
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
|
||||
|
||||
Returns:
|
||||
The number of bytes loaded into VRAM.
|
||||
"""
|
||||
vram_bytes_loaded = 0
|
||||
|
||||
# TODO(ryand): Iterate over buffers too?
|
||||
for key, param in self._model.named_parameters():
|
||||
# Skip parameters that are already on the compute device.
|
||||
if param.device.type == self._compute_device.type:
|
||||
continue
|
||||
|
||||
# Check the size of the parameter.
|
||||
param_size = calc_tensor_size(param)
|
||||
if vram_bytes_loaded + param_size > vram_bytes_to_load:
|
||||
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
|
||||
# worth continuing to search for a smaller parameter that would fit?
|
||||
continue
|
||||
|
||||
# Copy the parameter to the compute device.
|
||||
# We use the 'overwrite' strategy from torch.nn.Module._apply().
|
||||
# TODO(ryand): For some edge cases (e.g. quantized models?), we may need to support other strategies (e.g.
|
||||
# swap).
|
||||
assert isinstance(param, torch.nn.Parameter)
|
||||
assert param.is_leaf
|
||||
out_param = torch.nn.Parameter(param.to(self._compute_device, copy=True), requires_grad=param.requires_grad)
|
||||
set_nested_attr(self._model, key, out_param)
|
||||
# We did not port the param.grad handling from torch.nn.Module._apply(), because we do not expect to be
|
||||
# handling gradients. We assert that this assumption is true.
|
||||
assert param.grad is None
|
||||
|
||||
vram_bytes_loaded += param_size
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes += vram_bytes_loaded
|
||||
|
||||
return vram_bytes_loaded
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
|
||||
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
|
||||
|
||||
Returns:
|
||||
The number of bytes unloaded from VRAM.
|
||||
"""
|
||||
vram_bytes_freed = 0
|
||||
|
||||
# TODO(ryand): Iterate over buffers too?
|
||||
for key, param in self._model.named_parameters():
|
||||
if vram_bytes_freed >= vram_bytes_to_free:
|
||||
break
|
||||
|
||||
if param.device.type != self._compute_device.type:
|
||||
continue
|
||||
|
||||
# Create a new parameter, but inject the existing CPU tensor into it.
|
||||
out_param = torch.nn.Parameter(self._cpu_state_dict[key], requires_grad=param.requires_grad)
|
||||
set_nested_attr(self._model, key, out_param)
|
||||
vram_bytes_freed += calc_tensor_size(param)
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes -= vram_bytes_freed
|
||||
|
||||
return vram_bytes_freed
|
||||
534
invokeai/backend/model_manager/load/model_cache/model_cache.py
Normal file
534
invokeai/backend/model_manager/load/model_cache/model_cache.py
Normal file
@@ -0,0 +1,534 @@
|
||||
import gc
|
||||
from logging import Logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.prefix_logger_adapter import PrefixedLoggerAdapter
|
||||
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
|
||||
def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
|
||||
"""Get the cache key for a model based on the optional submodel type."""
|
||||
if submodel_type:
|
||||
return f"{model_key}:{submodel_type.value}"
|
||||
else:
|
||||
return model_key
|
||||
|
||||
|
||||
class ModelCache:
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
The cache is based on two levels of model storage:
|
||||
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
|
||||
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
|
||||
|
||||
The model cache is based on the following assumptions:
|
||||
- storage_device_mem_size > execution_device_mem_size
|
||||
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
|
||||
|
||||
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
|
||||
the execution_device.
|
||||
|
||||
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
|
||||
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
|
||||
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
|
||||
|
||||
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
|
||||
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
|
||||
configuration.
|
||||
|
||||
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
|
||||
the context, and unload outside the context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
|
||||
do_something_on_gpu(SD1)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
lazy_offloading: bool = True,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
# TODO(ryand): Think about what lazy_offloading should mean in the new model cache.
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._logger = PrefixedLoggerAdapter(
|
||||
logger or InvokeAILogger.get_logger(self.__class__.__name__), "MODEL CACHE"
|
||||
)
|
||||
self._log_memory_usage = log_memory_usage
|
||||
self._stats: Optional[CacheStats] = None
|
||||
|
||||
self._cached_models: Dict[str, CacheRecord] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
@property
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the cap on cache size."""
|
||||
return self._max_cache_size
|
||||
|
||||
@max_cache_size.setter
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the cap on vram cache size."""
|
||||
return self._max_vram_cache_size
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
def max_vram_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
self._max_vram_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
return self._stats
|
||||
|
||||
@stats.setter
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collecting cache statistics."""
|
||||
self._stats = stats
|
||||
|
||||
def put(self, key: str, model: AnyModel) -> None:
|
||||
"""Add a model to the cache."""
|
||||
if key in self._cached_models:
|
||||
self._logger.debug(
|
||||
f"Attempted to add model {key} ({model.__class__.__name__}), but it already exists in the cache. No action necessary."
|
||||
)
|
||||
return
|
||||
|
||||
size = calc_model_size_by_data(self._logger, model)
|
||||
self.make_room(size)
|
||||
|
||||
# Wrap model.
|
||||
if isinstance(model, torch.nn.Module):
|
||||
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
|
||||
else:
|
||||
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
|
||||
|
||||
# running_on_cpu = self._execution_device == torch.device("cpu")
|
||||
# state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
|
||||
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
self._logger.debug(
|
||||
f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)"
|
||||
)
|
||||
|
||||
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
|
||||
"""Retrieve a model from the cache.
|
||||
|
||||
:param key: Model key
|
||||
:param stats_name: A human-readable id for the model for the purposes of stats reporting.
|
||||
|
||||
Raises IndexError if the model is not in the cache.
|
||||
"""
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
self._logger.debug(f"Cache miss: {key}")
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.cached_model.total_bytes()
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
self._cache_stack = [k for k in self._cache_stack if k != key]
|
||||
self._cache_stack.append(key)
|
||||
|
||||
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
||||
|
||||
return cache_entry
|
||||
|
||||
def lock(self, key: str) -> None:
|
||||
"""Lock a model for use and move it into VRAM."""
|
||||
cache_entry = self._cached_models[key]
|
||||
cache_entry.lock()
|
||||
|
||||
self._logger.debug(f"Locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
||||
|
||||
try:
|
||||
self._load_locked_model(cache_entry)
|
||||
self._logger.debug(
|
||||
f"Finished locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
|
||||
)
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self._logger.warning("Insufficient GPU memory to load model. Aborting")
|
||||
cache_entry.unlock()
|
||||
raise
|
||||
except Exception:
|
||||
cache_entry.unlock()
|
||||
raise
|
||||
|
||||
self._log_cache_state()
|
||||
|
||||
def unlock(self, key: str) -> None:
|
||||
"""Unlock a model."""
|
||||
cache_entry = self._cached_models[key]
|
||||
cache_entry.unlock()
|
||||
self._logger.debug(f"Unlocked model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
||||
|
||||
def _load_locked_model(self, cache_entry: CacheRecord) -> None:
|
||||
"""Helper function for self.lock(). Loads a locked model into VRAM."""
|
||||
vram_available = self._get_vram_available()
|
||||
|
||||
# The amount of additional VRAM that will be used if we fully load the model into VRAM.
|
||||
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
|
||||
model_total_bytes = cache_entry.cached_model.total_bytes()
|
||||
model_vram_needed = model_total_bytes - model_cur_vram_bytes
|
||||
|
||||
self._logger.debug(
|
||||
f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
# Make room for the model in VRAM.
|
||||
# 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
|
||||
# 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
|
||||
# possible.
|
||||
vram_bytes_freed = self._offload_unlocked_models(model_vram_needed)
|
||||
self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")
|
||||
|
||||
# Check the updated vram_available after offloading.
|
||||
vram_available = self._get_vram_available()
|
||||
self._logger.debug(
|
||||
f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
# Move as much of the model as possible into VRAM.
|
||||
model_bytes_loaded = 0
|
||||
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
|
||||
model_bytes_loaded = cache_entry.cached_model.partial_load_to_vram(vram_available)
|
||||
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
|
||||
# Partial load is not supported, so we have not choice but to try and fit it all into VRAM.
|
||||
model_bytes_loaded = cache_entry.cached_model.full_load_to_vram()
|
||||
else:
|
||||
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
|
||||
|
||||
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
|
||||
vram_available = self._get_vram_available()
|
||||
self._logger.debug(f"Loaded model onto execution device: model_bytes_loaded={(model_bytes_loaded/MB):.2f}MB, ")
|
||||
self._logger.debug(
|
||||
f"After loading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
def _get_vram_available(self) -> int:
|
||||
"""Get the amount of VRAM available in the cache."""
|
||||
return int(self._max_vram_cache_size * GB) - self._get_vram_in_use()
|
||||
|
||||
def _get_vram_in_use(self) -> int:
|
||||
"""Get the amount of VRAM currently in use."""
|
||||
return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
|
||||
|
||||
def _get_ram_available(self) -> int:
|
||||
"""Get the amount of RAM available in the cache."""
|
||||
return int(self._max_cache_size * GB) - self._get_ram_in_use()
|
||||
|
||||
def _get_ram_in_use(self) -> int:
|
||||
"""Get the amount of RAM currently in use."""
|
||||
return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
return None
|
||||
|
||||
def _get_vram_state_str(self, model_cur_vram_bytes: int, model_total_bytes: int, vram_available: int) -> str:
|
||||
"""Helper function for preparing a VRAM state log string."""
|
||||
model_cur_vram_bytes_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
|
||||
return (
|
||||
f"model_total={model_total_bytes/MB:.0f} MB, "
|
||||
+ f"model_vram={model_cur_vram_bytes/MB:.0f} MB ({model_cur_vram_bytes_percent:.1%} %), "
|
||||
+ f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, "
|
||||
+ f"vram_available={(vram_available/MB):.0f} MB, "
|
||||
)
|
||||
|
||||
def _offload_unlocked_models(self, vram_bytes_to_free: int) -> int:
|
||||
"""Offload models from the execution_device until vram_bytes_to_free bytes are freed, or all models are
|
||||
offloaded. Of course, locked models are not offloaded.
|
||||
|
||||
Returns:
|
||||
int: The number of bytes freed.
|
||||
"""
|
||||
self._logger.debug(f"Offloading unlocked models with goal of freeing {vram_bytes_to_free/MB:.2f}MB of VRAM.")
|
||||
vram_bytes_freed = 0
|
||||
# TODO(ryand): Give more thought to the offloading policy used here.
|
||||
cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
|
||||
for cache_entry in cache_entries_increasing_size:
|
||||
if vram_bytes_freed >= vram_bytes_to_free:
|
||||
break
|
||||
if cache_entry.is_locked:
|
||||
continue
|
||||
|
||||
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
|
||||
cache_entry_bytes_freed = cache_entry.cached_model.partial_unload_from_vram(
|
||||
vram_bytes_to_free - vram_bytes_freed
|
||||
)
|
||||
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
|
||||
cache_entry_bytes_freed = cache_entry.cached_model.full_unload_from_vram()
|
||||
else:
|
||||
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
|
||||
if cache_entry_bytes_freed > 0:
|
||||
self._logger.debug(
|
||||
f"Unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/MB):.0f} MB."
|
||||
)
|
||||
vram_bytes_freed += cache_entry_bytes_freed
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
return vram_bytes_freed
|
||||
|
||||
# def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
|
||||
# """Move model into the indicated device.
|
||||
|
||||
# :param cache_entry: The CacheRecord for the model
|
||||
# :param target_device: The torch.device to move the model into
|
||||
|
||||
# May raise a torch.cuda.OutOfMemoryError
|
||||
# """
|
||||
# self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
# source_device = cache_entry.device
|
||||
|
||||
# # Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# # This would need to be revised to support multi-GPU.
|
||||
# if torch.device(source_device).type == torch.device(target_device).type:
|
||||
# return
|
||||
|
||||
# # Some models don't have a `to` method, in which case they run in RAM/CPU.
|
||||
# if not hasattr(cache_entry.model, "to"):
|
||||
# return
|
||||
|
||||
# # This roundabout method for moving the model around is done to avoid
|
||||
# # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# # When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# # RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# # This operation is slightly faster than running `to()` on the whole model.
|
||||
# #
|
||||
# # When the model needs to be removed from VRAM we simply delete the copy
|
||||
# # of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# # in RAM into the model. So this operation is very fast.
|
||||
# start_model_to_time = time.time()
|
||||
# snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
# try:
|
||||
# if cache_entry.state_dict is not None:
|
||||
# assert hasattr(cache_entry.model, "load_state_dict")
|
||||
# if target_device == self._storage_device:
|
||||
# cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
# else:
|
||||
# new_dict: Dict[str, torch.Tensor] = {}
|
||||
# for k, v in cache_entry.state_dict.items():
|
||||
# new_dict[k] = v.to(target_device, copy=True)
|
||||
# cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
# cache_entry.model.to(target_device)
|
||||
# cache_entry.device = target_device
|
||||
# except Exception as e: # blow away cache entry
|
||||
# self._delete_cache_entry(cache_entry)
|
||||
# raise e
|
||||
|
||||
# snapshot_after = self._capture_memory_snapshot()
|
||||
# end_model_to_time = time.time()
|
||||
# self._logger.debug(
|
||||
# f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
# f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
# f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
# )
|
||||
|
||||
# if (
|
||||
# snapshot_before is not None
|
||||
# and snapshot_after is not None
|
||||
# and snapshot_before.vram is not None
|
||||
# and snapshot_after.vram is not None
|
||||
# ):
|
||||
# vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# # If the estimated model size does not match the change in VRAM, log a warning.
|
||||
# if not math.isclose(
|
||||
# vram_change,
|
||||
# cache_entry.size,
|
||||
# rel_tol=0.1,
|
||||
# abs_tol=10 * MB,
|
||||
# ):
|
||||
# self._logger.debug(
|
||||
# f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
# f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
# " estimated size may be incorrect. Estimated model size:"
|
||||
# f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
# )
|
||||
|
||||
def _log_cache_state(self, title: str = "Model cache state:", include_entry_details: bool = True):
|
||||
ram_size_bytes = self._max_cache_size * GB
|
||||
ram_in_use_bytes = self._get_ram_in_use()
|
||||
ram_in_use_bytes_percent = ram_in_use_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
|
||||
ram_available_bytes = self._get_ram_available()
|
||||
ram_available_bytes_percent = ram_available_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
|
||||
|
||||
vram_size_bytes = self._max_vram_cache_size * GB
|
||||
vram_in_use_bytes = self._get_vram_in_use()
|
||||
vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
|
||||
vram_available_bytes = self._get_vram_available()
|
||||
vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
|
||||
|
||||
log = f"{title}\n"
|
||||
|
||||
log_format = " {:<30} Limit: {:>7.1f} MB, Used: {:>7.1f} MB ({:>5.1%}), Available: {:>7.1f} MB ({:>5.1%})\n"
|
||||
log += log_format.format(
|
||||
f"Storage Device ({self._storage_device.type})",
|
||||
ram_size_bytes / MB,
|
||||
ram_in_use_bytes / MB,
|
||||
ram_in_use_bytes_percent,
|
||||
ram_available_bytes / MB,
|
||||
ram_available_bytes_percent,
|
||||
)
|
||||
log += log_format.format(
|
||||
f"Compute Device ({self._execution_device.type})",
|
||||
vram_size_bytes / MB,
|
||||
vram_in_use_bytes / MB,
|
||||
vram_in_use_bytes_percent,
|
||||
vram_available_bytes / MB,
|
||||
vram_available_bytes_percent,
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
log += " {:<30} {} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB)
|
||||
log += " {:<30} {}\n".format("Total models:", len(self._cached_models))
|
||||
|
||||
if include_entry_details and len(self._cached_models) > 0:
|
||||
log += " Models:\n"
|
||||
log_format = (
|
||||
" {:<80} total={:>7.1f} MB, vram={:>7.1f} MB ({:>5.1%}), ram={:>7.1f} MB ({:>5.1%}), locked={}\n"
|
||||
)
|
||||
for cache_record in self._cached_models.values():
|
||||
total_bytes = cache_record.cached_model.total_bytes()
|
||||
cur_vram_bytes = cache_record.cached_model.cur_vram_bytes()
|
||||
cur_vram_bytes_percent = cur_vram_bytes / total_bytes if total_bytes > 0 else 0
|
||||
cur_ram_bytes = total_bytes - cur_vram_bytes
|
||||
cur_ram_bytes_percent = cur_ram_bytes / total_bytes if total_bytes > 0 else 0
|
||||
|
||||
log += log_format.format(
|
||||
f"{cache_record.key} ({cache_record.cached_model.model.__class__.__name__}):",
|
||||
total_bytes / MB,
|
||||
cur_vram_bytes / MB,
|
||||
cur_vram_bytes_percent,
|
||||
cur_ram_bytes / MB,
|
||||
cur_ram_bytes_percent,
|
||||
cache_record.is_locked,
|
||||
)
|
||||
|
||||
self._logger.debug(log)
|
||||
|
||||
def make_room(self, bytes_needed: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
self._logger.debug(f"Making room for {bytes_needed/MB:.2f}MB of RAM.")
|
||||
self._log_cache_state(title="Before dropping models:")
|
||||
|
||||
ram_bytes_available = self._get_ram_available()
|
||||
ram_bytes_to_free = max(0, bytes_needed - ram_bytes_available)
|
||||
|
||||
ram_bytes_freed = 0
|
||||
pos = 0
|
||||
models_cleared = 0
|
||||
while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
|
||||
if not cache_entry.is_locked:
|
||||
ram_bytes_freed += cache_entry.cached_model.total_bytes()
|
||||
self._logger.debug(
|
||||
f"Dropping {model_key} from RAM cache to free {(cache_entry.cached_model.total_bytes()/MB):.2f}MB."
|
||||
)
|
||||
self._delete_cache_entry(cache_entry)
|
||||
del cache_entry
|
||||
models_cleared += 1
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
if models_cleared > 0:
|
||||
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
|
||||
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
|
||||
# is high even if no garbage gets collected.)
|
||||
#
|
||||
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
|
||||
# - If models had to be cleared, it's a signal that we are close to our memory limit.
|
||||
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
|
||||
# collected.
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self._logger.debug(f"Dropped {models_cleared} models to free {ram_bytes_freed/MB:.2f}MB of RAM.")
|
||||
self._log_cache_state(title="After dropping models:")
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
@@ -1,221 +0,0 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from typing import Dict, Generic, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModel, SubModelType
|
||||
|
||||
|
||||
class ModelLockerBase(ABC):
|
||||
"""Base class for the model locker used by the loader."""
|
||||
|
||||
@abstractmethod
|
||||
def lock(self) -> AnyModel:
|
||||
"""Lock the contained model and move it into VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unlock(self) -> None:
|
||||
"""Unlock the contained model, and remove it from VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model."""
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheRecord(Generic[T]):
|
||||
"""
|
||||
Elements of the cache:
|
||||
|
||||
key: Unique key for each model, same as used in the models database.
|
||||
model: Model in memory.
|
||||
state_dict: A read-only copy of the model's state dict in RAM. It will be
|
||||
used as a template for creating a copy in the VRAM.
|
||||
size: Size of the model
|
||||
loaded: True if the model's state dict is currently in VRAM
|
||||
|
||||
Before a model is executed, the state_dict template is copied into VRAM,
|
||||
and then injected into the model. When the model is finished, the VRAM
|
||||
copy of the state dict is deleted, and the RAM version is reinjected
|
||||
into the model.
|
||||
|
||||
The state_dict should be treated as a read-only attribute. Do not attempt
|
||||
to patch or otherwise modify it. Instead, patch the copy of the state_dict
|
||||
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
|
||||
context manager call `model_on_device()`.
|
||||
"""
|
||||
|
||||
key: str
|
||||
model: T
|
||||
device: torch.device
|
||||
state_dict: Optional[Dict[str, torch.Tensor]]
|
||||
size: int
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
"""Lock this record."""
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Unlock this record."""
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self) -> bool:
|
||||
"""Return true if record is locked."""
|
||||
return self._locks > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
"""Collect statistics on cache performance."""
|
||||
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelCacheBase(ABC, Generic[T]):
|
||||
"""Virtual base class for RAM model cache."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def storage_device(self) -> torch.device:
|
||||
"""Return the storage device (e.g. "CPU" for RAM)."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the maximum size the RAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self, value: float) -> float:
|
||||
"""Set the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload from VRAM any models not actively in use."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
pass
|
||||
|
||||
@stats.setter
|
||||
@abstractmethod
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collectin cache statistics."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self) -> Logger:
|
||||
"""Return the logger used by the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
model: T,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> ModelLockerBase:
|
||||
"""
|
||||
Retrieve model using key and optional submodel_type.
|
||||
|
||||
:param key: Opaque model key
|
||||
:param submodel_type: Type of the submodel to fetch
|
||||
:param stats_name: A human-readable id for the model for the purposes of
|
||||
stats reporting.
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log debugging information on CUDA usage."""
|
||||
pass
|
||||
@@ -1,426 +0,0 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
""" """
|
||||
|
||||
import gc
|
||||
import math
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from logging import Logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
|
||||
CacheRecord,
|
||||
CacheStats,
|
||||
ModelCacheBase,
|
||||
ModelLockerBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
The cache is based on two levels of model storage:
|
||||
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
|
||||
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
|
||||
|
||||
The model cache is based on the following assumptions:
|
||||
- storage_device_mem_size > execution_device_mem_size
|
||||
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
|
||||
|
||||
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
|
||||
the execution_device.
|
||||
|
||||
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
|
||||
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
|
||||
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
|
||||
|
||||
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
|
||||
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
|
||||
configuration.
|
||||
|
||||
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
|
||||
the context, and unload outside the context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
|
||||
do_something_on_gpu(SD1)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
lazy_offloading: bool = True,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
self._log_memory_usage = log_memory_usage
|
||||
self._stats: Optional[CacheStats] = None
|
||||
|
||||
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
"""Return the logger used by the cache."""
|
||||
return self._logger
|
||||
|
||||
@property
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
return self._lazy_offloading
|
||||
|
||||
@property
|
||||
def storage_device(self) -> torch.device:
|
||||
"""Return the storage device (e.g. "CPU" for RAM)."""
|
||||
return self._storage_device
|
||||
|
||||
@property
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
return self._execution_device
|
||||
|
||||
@property
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the cap on cache size."""
|
||||
return self._max_cache_size
|
||||
|
||||
@max_cache_size.setter
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the cap on vram cache size."""
|
||||
return self._max_vram_cache_size
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
def max_vram_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
self._max_vram_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
return self._stats
|
||||
|
||||
@stats.setter
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collectin cache statistics."""
|
||||
self._stats = stats
|
||||
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
total = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
total += cache_record.size
|
||||
return total
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
model: AnyModel,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
return
|
||||
size = calc_model_size_by_data(self.logger, model)
|
||||
self.make_room(size)
|
||||
|
||||
running_on_cpu = self.execution_device == torch.device("cpu")
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> ModelLockerBase:
|
||||
"""
|
||||
Retrieve model using key and optional submodel_type.
|
||||
|
||||
:param key: Opaque model key
|
||||
:param submodel_type: Type of the submodel to fetch
|
||||
:param stats_name: A human-readable id for the model for the purposes of
|
||||
stats reporting.
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
)
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
return None
|
||||
|
||||
def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
|
||||
if submodel_type:
|
||||
return f"{model_key}:{submodel_type.value}"
|
||||
else:
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload models from the execution_device to make room for size_required.
|
||||
|
||||
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
|
||||
"""
|
||||
reserved = self._max_vram_cache_size * GB
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
if not cache_entry.locked:
|
||||
self.move_model_to_device(cache_entry, self.storage_device)
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device.
|
||||
|
||||
:param cache_entry: The CacheRecord for the model
|
||||
:param target_device: The torch.device to move the model into
|
||||
|
||||
May raise a torch.cuda.OutOfMemoryError
|
||||
"""
|
||||
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
source_device = cache_entry.device
|
||||
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
# Some models don't have a `to` method, in which case they run in RAM/CPU.
|
||||
if not hasattr(cache_entry.model, "to"):
|
||||
return
|
||||
|
||||
# This roundabout method for moving the model around is done to avoid
|
||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# This operation is slightly faster than running `to()` on the whole model.
|
||||
#
|
||||
# When the model needs to be removed from VRAM we simply delete the copy
|
||||
# of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# in RAM into the model. So this operation is very fast.
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
try:
|
||||
if cache_entry.state_dict is not None:
|
||||
assert hasattr(cache_entry.model, "load_state_dict")
|
||||
if target_device == self.storage_device:
|
||||
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(target_device, copy=True)
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device)
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if (
|
||||
snapshot_before is not None
|
||||
and snapshot_after is not None
|
||||
and snapshot_before.vram is not None
|
||||
and snapshot_after.vram is not None
|
||||
):
|
||||
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# If the estimated model size does not match the change in VRAM, log a warning.
|
||||
if not math.isclose(
|
||||
vram_change,
|
||||
cache_entry.size,
|
||||
rel_tol=0.1,
|
||||
abs_tol=10 * MB,
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
|
||||
ram = "%4.2fG" % (self.cache_size() / GB)
|
||||
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
locked_in_vram_models = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
if hasattr(cache_record.model, "device"):
|
||||
if cache_record.model.device == self.storage_device:
|
||||
in_ram_models += 1
|
||||
else:
|
||||
in_vram_models += 1
|
||||
if cache_record.locked:
|
||||
locked_in_vram_models += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
|
||||
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
bytes_needed = size
|
||||
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
|
||||
current_size = self.cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GB):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
|
||||
pos = 0
|
||||
models_cleared = 0
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
|
||||
)
|
||||
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
self._delete_cache_entry(cache_entry)
|
||||
del cache_entry
|
||||
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
if models_cleared > 0:
|
||||
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
|
||||
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
|
||||
# is high even if no garbage gets collected.)
|
||||
#
|
||||
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
|
||||
# - If models had to be cleared, it's a signal that we are close to our memory limit.
|
||||
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
|
||||
# collected.
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
@@ -1,64 +0,0 @@
|
||||
"""
|
||||
Base class and implementation of a class that moves models in and out of VRAM.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
|
||||
CacheRecord,
|
||||
ModelCacheBase,
|
||||
ModelLockerBase,
|
||||
)
|
||||
|
||||
|
||||
class ModelLocker(ModelLockerBase):
|
||||
"""Internal class that mediates movement in and out of GPU."""
|
||||
|
||||
def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
|
||||
"""
|
||||
Initialize the model locker.
|
||||
|
||||
:param cache: The ModelCache object
|
||||
:param cache_entry: The entry in the model cache
|
||||
"""
|
||||
self._cache = cache
|
||||
self._cache_entry = cache_entry
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without moving it around."""
|
||||
return self._cache_entry.model
|
||||
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
return self._cache_entry.state_dict
|
||||
|
||||
def lock(self) -> AnyModel:
|
||||
"""Move the model into the execution device (GPU) and lock it."""
|
||||
self._cache_entry.lock()
|
||||
try:
|
||||
if self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
|
||||
self._cache_entry.loaded = True
|
||||
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
|
||||
self._cache.print_cuda_stats()
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
except Exception:
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
|
||||
return self.model
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Call upon exit from context."""
|
||||
self._cache_entry.unlock()
|
||||
if not self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(0)
|
||||
self._cache.print_cuda_stats()
|
||||
@@ -0,0 +1,33 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
from torch.overrides import TorchFunctionMode
|
||||
|
||||
|
||||
def add_autocast_to_module_forward(m: torch.nn.Module, to_device: torch.device):
|
||||
"""Monkey-patch m.forward(...) with a new forward(...) method that activates device autocasting for its duration."""
|
||||
old_forward = m.forward
|
||||
|
||||
def new_forward(*args: Any, **kwargs: Any):
|
||||
with TorchFunctionAutocastDeviceContext(to_device):
|
||||
return old_forward(*args, **kwargs)
|
||||
|
||||
m.forward = new_forward
|
||||
|
||||
|
||||
def _cast_to_device_and_run(
|
||||
func: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], to_device: torch.device
|
||||
):
|
||||
args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
|
||||
kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
return func(*args_on_device, **kwargs_on_device)
|
||||
|
||||
|
||||
class TorchFunctionAutocastDeviceContext(TorchFunctionMode):
|
||||
def __init__(self, to_device: torch.device):
|
||||
self._to_device = to_device
|
||||
|
||||
def __torch_function__(
|
||||
self, func: Callable[..., Any], types, args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None
|
||||
):
|
||||
return _cast_to_device_and_run(func, args, kwargs or {}, self._to_device)
|
||||
@@ -26,7 +26,7 @@ from invokeai.backend.model_manager import (
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class LoRALoader(ModelLoader):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
super().__init__(app_config, logger, ram_cache)
|
||||
|
||||
@@ -25,6 +25,7 @@ from invokeai.backend.model_manager.config import (
|
||||
DiffusersConfigBase,
|
||||
MainCheckpointConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
@@ -132,5 +133,5 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
if subtype == submodel_type:
|
||||
continue
|
||||
if submodel := getattr(pipeline, subtype.value, None):
|
||||
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
|
||||
self._ram_cache.put(get_model_cache_key(config.key, subtype), model=submodel)
|
||||
return getattr(pipeline, submodel_type.value)
|
||||
|
||||
12
invokeai/backend/util/prefix_logger_adapter.py
Normal file
12
invokeai/backend/util/prefix_logger_adapter.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import logging
|
||||
from typing import Any, MutableMapping
|
||||
|
||||
|
||||
# Issue with type hints related to LoggerAdapter: https://github.com/python/typeshed/issues/7855
|
||||
class PrefixedLoggerAdapter(logging.LoggerAdapter): # type: ignore
|
||||
def __init__(self, logger: logging.Logger, prefix: str):
|
||||
super().__init__(logger, {})
|
||||
self.prefix = prefix
|
||||
|
||||
def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple[str, MutableMapping[str, Any]]:
|
||||
return f"[{self.prefix}] {msg}", kwargs
|
||||
@@ -96,7 +96,9 @@
|
||||
"new": "Neu",
|
||||
"ok": "OK",
|
||||
"close": "Schließen",
|
||||
"clipboard": "Zwischenablage"
|
||||
"clipboard": "Zwischenablage",
|
||||
"generating": "Generieren",
|
||||
"loadingModel": "Lade Modell"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Bildgröße",
|
||||
@@ -591,7 +593,15 @@
|
||||
"loraTriggerPhrases": "LoRA-Auslösephrasen",
|
||||
"installingBundle": "Bündel wird installiert",
|
||||
"triggerPhrases": "Auslösephrasen",
|
||||
"mainModelTriggerPhrases": "Hauptmodell-Auslösephrasen"
|
||||
"mainModelTriggerPhrases": "Hauptmodell-Auslösephrasen",
|
||||
"noDefaultSettings": "Für dieses Modell sind keine Standardeinstellungen konfiguriert. Besuchen Sie den Modell-Manager, um Standardeinstellungen hinzuzufügen.",
|
||||
"defaultSettingsOutOfSync": "Einige Einstellungen stimmen nicht mit den Standardeinstellungen des Modells überein:",
|
||||
"clipLEmbed": "CLIP-L einbetten",
|
||||
"clipGEmbed": "CLIP-G einbetten",
|
||||
"hfTokenLabel": "HuggingFace-Token (für einige Modelle erforderlich)",
|
||||
"hfTokenHelperText": "Für die Nutzung einiger Modelle ist ein HF-Token erforderlich. Klicken Sie hier, um Ihr Token zu erstellen oder zu erhalten.",
|
||||
"hfForbidden": "Sie haben keinen Zugriff auf dieses HF-Modell",
|
||||
"hfTokenInvalid": "Ungültiges oder fehlendes HF-Token"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Bilder",
|
||||
@@ -841,7 +851,8 @@
|
||||
"upscaling": "Hochskalierung",
|
||||
"canvas": "Leinwand",
|
||||
"prompts_one": "Prompt",
|
||||
"prompts_other": "Prompts"
|
||||
"prompts_other": "Prompts",
|
||||
"batchSize": "Stapelgröße"
|
||||
},
|
||||
"metadata": {
|
||||
"negativePrompt": "Negativ Beschreibung",
|
||||
@@ -1081,6 +1092,21 @@
|
||||
},
|
||||
"patchmatchDownScaleSize": {
|
||||
"heading": "Herunterskalieren"
|
||||
},
|
||||
"paramHeight": {
|
||||
"heading": "Höhe",
|
||||
"paragraphs": [
|
||||
"Höhe des generierten Bildes. Muss ein Vielfaches von 8 sein."
|
||||
]
|
||||
},
|
||||
"paramUpscaleMethod": {
|
||||
"heading": "Vergrößerungsmethode",
|
||||
"paragraphs": [
|
||||
"Methode zum Hochskalieren des Bildes für High Resolution Fix."
|
||||
]
|
||||
},
|
||||
"paramHrf": {
|
||||
"heading": "High Resolution Fix aktivieren"
|
||||
}
|
||||
},
|
||||
"invocationCache": {
|
||||
|
||||
@@ -176,7 +176,8 @@
|
||||
"reset": "Reset",
|
||||
"none": "None",
|
||||
"new": "New",
|
||||
"generating": "Generating"
|
||||
"generating": "Generating",
|
||||
"warnings": "Warnings"
|
||||
},
|
||||
"hrf": {
|
||||
"hrf": "High Resolution Fix",
|
||||
@@ -1038,20 +1039,7 @@
|
||||
"canvasIsSelectingObject": "Canvas is busy (selecting object)",
|
||||
"noPrompts": "No prompts generated",
|
||||
"noNodesInGraph": "No nodes in graph",
|
||||
"systemDisconnected": "System disconnected",
|
||||
"layer": {
|
||||
"controlAdapterNoModelSelected": "no Control Adapter model selected",
|
||||
"controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
|
||||
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, bbox width is {{width}}",
|
||||
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, bbox height is {{height}}",
|
||||
"t2iAdapterIncompatibleScaledBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, scaled bbox width is {{width}}",
|
||||
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, scaled bbox height is {{height}}",
|
||||
"ipAdapterNoModelSelected": "no IP adapter selected",
|
||||
"ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
|
||||
"ipAdapterNoImageSelected": "no IP Adapter image selected",
|
||||
"rgNoPromptsOrIPAdapters": "no text prompts or IP Adapters",
|
||||
"rgNoRegion": "no region selected"
|
||||
}
|
||||
"systemDisconnected": "System disconnected"
|
||||
},
|
||||
"maskBlur": "Mask Blur",
|
||||
"negativePromptPlaceholder": "Negative Prompt",
|
||||
@@ -1713,6 +1701,8 @@
|
||||
"controlLayer": "Control Layer",
|
||||
"inpaintMask": "Inpaint Mask",
|
||||
"regionalGuidance": "Regional Guidance",
|
||||
"referenceImageRegional": "Reference Image (Regional)",
|
||||
"referenceImageGlobal": "Reference Image (Global)",
|
||||
"asRasterLayer": "As $t(controlLayers.rasterLayer)",
|
||||
"asRasterLayerResize": "As $t(controlLayers.rasterLayer) (Resize)",
|
||||
"asControlLayer": "As $t(controlLayers.controlLayer)",
|
||||
@@ -1798,6 +1788,21 @@
|
||||
"replaceCurrent": "Replace Current",
|
||||
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or draw on the canvas to get started.",
|
||||
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer to get started.",
|
||||
"warnings": {
|
||||
"problemsFound": "Problems found",
|
||||
"unsupportedModel": "layer not supported for selected base model",
|
||||
"controlAdapterNoModelSelected": "no Control Layer model selected",
|
||||
"controlAdapterIncompatibleBaseModel": "incompatible Control Layer base model",
|
||||
"controlAdapterNoControl": "no control selected/drawn",
|
||||
"ipAdapterNoModelSelected": "no Reference Image model selected",
|
||||
"ipAdapterIncompatibleBaseModel": "incompatible Reference Image base model",
|
||||
"ipAdapterNoImageSelected": "no Reference Image image selected",
|
||||
"rgNoPromptsOrIPAdapters": "no text prompts or Reference Images",
|
||||
"rgNegativePromptNotSupported": "Negative Prompt not supported for selected base model",
|
||||
"rgReferenceImagesNotSupported": "regional Reference Images not supported for selected base model",
|
||||
"rgAutoNegativeNotSupported": "Auto-Negative not supported for selected base model",
|
||||
"rgNoRegion": "no region drawn"
|
||||
},
|
||||
"controlMode": {
|
||||
"controlMode": "Control Mode",
|
||||
"balanced": "Balanced (recommended)",
|
||||
|
||||
@@ -327,7 +327,6 @@
|
||||
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, la hauteur de la bounding box est {{height}}",
|
||||
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, la largeur de la bounding box est {{width}}",
|
||||
"ipAdapterIncompatibleBaseModel": "modèle de base d'IP adapter incompatible",
|
||||
"rgNoRegion": "aucune zone sélectionnée",
|
||||
"controlAdapterNoModelSelected": "aucun modèle de Control Adapter sélectionné"
|
||||
},
|
||||
"noPrompts": "Aucun prompts généré",
|
||||
|
||||
@@ -96,7 +96,8 @@
|
||||
"clipboard": "Appunti",
|
||||
"ok": "Ok",
|
||||
"generating": "Generazione",
|
||||
"loadingModel": "Caricamento del modello"
|
||||
"loadingModel": "Caricamento del modello",
|
||||
"warnings": "Avvisi"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@@ -671,11 +672,15 @@
|
||||
"ipAdapterIncompatibleBaseModel": "Il modello base dell'adattatore IP non è compatibile",
|
||||
"ipAdapterNoImageSelected": "Nessuna immagine dell'adattatore IP selezionata",
|
||||
"rgNoPromptsOrIPAdapters": "Nessun prompt o adattatore IP",
|
||||
"rgNoRegion": "Nessuna regione selezionata",
|
||||
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, larghezza riquadro è {{width}}",
|
||||
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, altezza riquadro è {{height}}",
|
||||
"t2iAdapterIncompatibleScaledBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, larghezza del riquadro scalato {{width}}",
|
||||
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, altezza del riquadro scalato {{height}}"
|
||||
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, altezza del riquadro scalato {{height}}",
|
||||
"rgNegativePromptNotSupported": "prompt negativo non supportato per il modello base selezionato",
|
||||
"rgAutoNegativeNotSupported": "auto-negativo non supportato per il modello base selezionato",
|
||||
"emptyLayer": "livello vuoto",
|
||||
"unsupportedModel": "livello non supportato per il modello base selezionato",
|
||||
"rgReferenceImagesNotSupported": "immagini di riferimento regionali non supportate per il modello base selezionato"
|
||||
},
|
||||
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), altezza riquadro è {{height}}",
|
||||
"fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), larghezza riquadro è {{width}}",
|
||||
@@ -687,7 +692,11 @@
|
||||
"canvasIsTransforming": "La tela sta trasformando",
|
||||
"canvasIsRasterizing": "La tela sta rasterizzando",
|
||||
"canvasIsCompositing": "La tela è in fase di composizione",
|
||||
"canvasIsFiltering": "La tela sta filtrando"
|
||||
"canvasIsFiltering": "La tela sta filtrando",
|
||||
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi elementi, massimo {{maxItems}}",
|
||||
"canvasIsSelectingObject": "La tela è occupata (selezione dell'oggetto)",
|
||||
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi pochi elementi, minimo {{minItems}}",
|
||||
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} raccolta vuota"
|
||||
},
|
||||
"useCpuNoise": "Usa la CPU per generare rumore",
|
||||
"iterations": "Iterazioni",
|
||||
@@ -972,7 +981,9 @@
|
||||
"saveToGallery": "Salva nella Galleria",
|
||||
"noMatchingWorkflows": "Nessun flusso di lavoro corrispondente",
|
||||
"noWorkflows": "Nessun flusso di lavoro",
|
||||
"workflowHelpText": "Hai bisogno di aiuto? Consulta la nostra guida <LinkComponent>Introduzione ai flussi di lavoro</LinkComponent>."
|
||||
"workflowHelpText": "Hai bisogno di aiuto? Consulta la nostra guida <LinkComponent>Introduzione ai flussi di lavoro</LinkComponent>.",
|
||||
"specialDesc": "Questa invocazione comporta una gestione speciale nell'applicazione. Ad esempio, i nodi Lotto vengono utilizzati per mettere in coda più grafici da un singolo flusso di lavoro.",
|
||||
"internalDesc": "Questa invocazione è utilizzata internamente da Invoke. Potrebbe subire modifiche significative durante gli aggiornamenti dell'app e potrebbe essere rimossa in qualsiasi momento."
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Aggiungi automaticamente bacheca",
|
||||
@@ -1093,7 +1104,8 @@
|
||||
"workflows": "Flussi di lavoro",
|
||||
"generation": "Generazione",
|
||||
"other": "Altro",
|
||||
"gallery": "Galleria"
|
||||
"gallery": "Galleria",
|
||||
"batchSize": "Dimensione del lotto"
|
||||
},
|
||||
"models": {
|
||||
"noMatchingModels": "Nessun modello corrispondente",
|
||||
@@ -1196,7 +1208,8 @@
|
||||
"heading": "Percentuale passi Inizio / Fine",
|
||||
"paragraphs": [
|
||||
"La parte del processo di rimozione del rumore in cui verrà applicato l'adattatore di controllo.",
|
||||
"In genere, gli adattatori di controllo applicati all'inizio del processo guidano la composizione, mentre quelli applicati alla fine guidano i dettagli."
|
||||
"In genere, gli adattatori di controllo applicati all'inizio del processo guidano la composizione, mentre quelli applicati alla fine guidano i dettagli.",
|
||||
"• Passo finale (%): specifica quando interrompere l'applicazione della guida di questo livello e ripristinare la guida generale dal modello e altre impostazioni."
|
||||
]
|
||||
},
|
||||
"noiseUseCPU": {
|
||||
@@ -1300,7 +1313,9 @@
|
||||
"controlNetWeight": {
|
||||
"heading": "Peso",
|
||||
"paragraphs": [
|
||||
"Peso dell'adattatore di controllo. Un peso maggiore porterà a impatti maggiori sull'immagine finale."
|
||||
"Regola la forza con cui il livello influenza il processo di generazione",
|
||||
"• Peso maggiore (0.75-2): crea un impatto più significativo sul risultato finale.",
|
||||
"• Peso inferiore (0-0.75): crea un impatto minore sul risultato finale."
|
||||
]
|
||||
},
|
||||
"paramCFGScale": {
|
||||
@@ -1801,7 +1816,10 @@
|
||||
"full": "Stile e Composizione",
|
||||
"style": "Solo Stile",
|
||||
"composition": "Solo Composizione",
|
||||
"ipAdapterMethod": "Metodo Adattatore IP"
|
||||
"ipAdapterMethod": "Metodo Adattatore IP",
|
||||
"fullDesc": "Applica lo stile visivo (colori, texture) e la composizione (disposizione, struttura).",
|
||||
"styleDesc": "Applica lo stile visivo (colori, texture) senza considerare la disposizione.",
|
||||
"compositionDesc": "Replica disposizione e struttura ignorando lo stile di riferimento."
|
||||
},
|
||||
"showingType": "Mostra {{type}}",
|
||||
"dynamicGrid": "Griglia dinamica",
|
||||
@@ -2044,7 +2062,16 @@
|
||||
"replaceCurrent": "Sostituisci corrente",
|
||||
"mergeDown": "Unire in basso",
|
||||
"mergingLayers": "Unione dei livelli",
|
||||
"controlLayerEmptyState": "<UploadButton>Carica un'immagine</UploadButton>, trascina un'immagine dalla <GalleryButton>galleria</GalleryButton> su questo livello oppure disegna sulla tela per iniziare."
|
||||
"controlLayerEmptyState": "<UploadButton>Carica un'immagine</UploadButton>, trascina un'immagine dalla <GalleryButton>galleria</GalleryButton> su questo livello oppure disegna sulla tela per iniziare.",
|
||||
"useImage": "Usa immagine",
|
||||
"resetGenerationSettings": "Ripristina impostazioni di generazione",
|
||||
"referenceImageEmptyState": "Per iniziare, <UploadButton>carica un'immagine</UploadButton> oppure trascina un'immagine dalla <GalleryButton>galleria</GalleryButton> su questo livello.",
|
||||
"asRasterLayer": "Come $t(controlLayers.rasterLayer)",
|
||||
"asRasterLayerResize": "Come $t(controlLayers.rasterLayer) (Ridimensiona)",
|
||||
"asControlLayer": "Come $t(controlLayers.controlLayer)",
|
||||
"asControlLayerResize": "Come $t(controlLayers.controlLayer) (Ridimensiona)",
|
||||
"newSession": "Nuova sessione",
|
||||
"resetCanvasLayers": "Ripristina livelli Tela"
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
@@ -2144,7 +2171,7 @@
|
||||
"watchRecentReleaseVideos": "Guarda i video su questa versione",
|
||||
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
|
||||
"items": [
|
||||
"<StrongComponent>SD 3.5</StrongComponent>: supporto per SD 3.5 Medium e Large.",
|
||||
"<StrongComponent>Flussi di lavoro</StrongComponent>: esegui un flusso di lavoro per una raccolta di immagini utilizzando il nuovo nodo <StrongComponent>Lotto di immagini</StrongComponent>.",
|
||||
"<StrongComponent>Tela</StrongComponent>: elaborazione semplificata del livello di controllo e impostazioni di controllo predefinite migliorate."
|
||||
]
|
||||
},
|
||||
@@ -2172,5 +2199,67 @@
|
||||
"logNamespaces": "Elementi del registro"
|
||||
},
|
||||
"enableLogging": "Abilita la registrazione"
|
||||
},
|
||||
"supportVideos": {
|
||||
"gettingStarted": "Iniziare",
|
||||
"supportVideos": "Video di supporto",
|
||||
"videos": {
|
||||
"usingControlLayersAndReferenceGuides": {
|
||||
"title": "Utilizzo di livelli di controllo e guide di riferimento",
|
||||
"description": "Scopri come guidare la creazione delle tue immagini con livelli di controllo e immagini di riferimento."
|
||||
},
|
||||
"creatingYourFirstImage": {
|
||||
"description": "Introduzione alla creazione di un'immagine da zero utilizzando gli strumenti di Invoke.",
|
||||
"title": "Creazione della tua prima immagine"
|
||||
},
|
||||
"understandingImageToImageAndDenoising": {
|
||||
"description": "Panoramica delle trasformazioni immagine-a-immagine e della riduzione del rumore in Invoke.",
|
||||
"title": "Comprendere immagine-a-immagine e riduzione del rumore"
|
||||
},
|
||||
"howDoIDoImageToImageTransformation": {
|
||||
"description": "Tutorial su come eseguire trasformazioni da immagine a immagine in Invoke.",
|
||||
"title": "Come si esegue la trasformazione da immagine-a-immagine?"
|
||||
},
|
||||
"howDoIUseInpaintMasks": {
|
||||
"title": "Come si usano le maschere Inpaint?",
|
||||
"description": "Come applicare maschere inpaint per la correzione e la variazione delle immagini."
|
||||
},
|
||||
"howDoIOutpaint": {
|
||||
"description": "Guida all'outpainting oltre i confini dell'immagine originale.",
|
||||
"title": "Come posso eseguire l'outpainting?"
|
||||
},
|
||||
"exploringAIModelsAndConceptAdapters": {
|
||||
"description": "Approfondisci i modelli di intelligenza artificiale e scopri come utilizzare gli adattatori concettuali per il controllo creativo.",
|
||||
"title": "Esplorazione dei modelli di IA e degli adattatori concettuali"
|
||||
},
|
||||
"upscaling": {
|
||||
"title": "Ampliamento",
|
||||
"description": "Come ampliare le immagini con gli strumenti di Invoke per migliorarne la risoluzione."
|
||||
},
|
||||
"creatingAndComposingOnInvokesControlCanvas": {
|
||||
"description": "Impara a comporre immagini utilizzando la tela di controllo di Invoke.",
|
||||
"title": "Creare e comporre sulla tela di controllo di Invoke"
|
||||
},
|
||||
"howDoIGenerateAndSaveToTheGallery": {
|
||||
"description": "Passaggi per generare e salvare le immagini nella galleria.",
|
||||
"title": "Come posso generare e salvare nella Galleria?"
|
||||
},
|
||||
"howDoIEditOnTheCanvas": {
|
||||
"title": "Come posso apportare modifiche sulla tela?",
|
||||
"description": "Guida alla modifica delle immagini direttamente sulla tela."
|
||||
},
|
||||
"howDoIUseControlNetsAndControlLayers": {
|
||||
"title": "Come posso utilizzare le Reti di Controllo e i Livelli di Controllo?",
|
||||
"description": "Impara ad applicare livelli di controllo e reti di controllo alle tue immagini."
|
||||
},
|
||||
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
|
||||
"title": "Come si utilizzano gli adattatori IP globali e le immagini di riferimento?",
|
||||
"description": "Introduzione all'aggiunta di immagini di riferimento e adattatori IP globali."
|
||||
}
|
||||
},
|
||||
"controlCanvas": "Tela di Controllo",
|
||||
"watch": "Guarda",
|
||||
"studioSessionsDesc1": "Dai un'occhiata a <StudioSessionsPlaylistLink /> per approfondimenti su Invoke.",
|
||||
"studioSessionsDesc2": "Unisciti al nostro <DiscordLink /> per partecipare alle sessioni live e fare domande. Le sessioni vengono caricate sulla playlist la settimana successiva."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,7 +236,6 @@
|
||||
"controlAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor controle-adapter",
|
||||
"ipAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor IP-adapter",
|
||||
"ipAdapterNoImageSelected": "geen afbeelding voor IP-adapter geselecteerd",
|
||||
"rgNoRegion": "geen gebied geselecteerd",
|
||||
"rgNoPromptsOrIPAdapters": "geen tekstprompts of IP-adapters",
|
||||
"ipAdapterNoModelSelected": "geen IP-adapter geselecteerd"
|
||||
}
|
||||
|
||||
@@ -10,7 +10,24 @@
|
||||
"load": "Załaduj",
|
||||
"statusDisconnected": "Odłączono od serwera",
|
||||
"githubLabel": "GitHub",
|
||||
"discordLabel": "Discord"
|
||||
"discordLabel": "Discord",
|
||||
"clipboard": "Schowek",
|
||||
"aboutDesc": "Wykorzystujesz Invoke do pracy? Sprawdź:",
|
||||
"ai": "SI",
|
||||
"areYouSure": "Czy jesteś pewien?",
|
||||
"copyError": "$t(gallery.copy) Błąd",
|
||||
"apply": "Zastosuj",
|
||||
"copy": "Kopiuj",
|
||||
"or": "albo",
|
||||
"add": "Dodaj",
|
||||
"off": "Wyłączony",
|
||||
"accept": "Zaakceptuj",
|
||||
"cancel": "Anuluj",
|
||||
"advanced": "Zawansowane",
|
||||
"back": "Do tyłu",
|
||||
"auto": "Automatyczny",
|
||||
"beta": "Beta",
|
||||
"close": "Wyjdź"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Rozmiar obrazów",
|
||||
@@ -65,6 +82,42 @@
|
||||
"uploadImage": "Wgrywanie obrazu",
|
||||
"previousImage": "Poprzedni obraz",
|
||||
"nextImage": "Następny obraz",
|
||||
"menu": "Menu"
|
||||
"menu": "Menu",
|
||||
"mode": "Tryb"
|
||||
},
|
||||
"boards": {
|
||||
"cancel": "Anuluj",
|
||||
"noBoards": "Brak tablic typu {{boardType}}",
|
||||
"imagesWithCount_one": "{{count}} zdjęcie",
|
||||
"imagesWithCount_few": "{{count}} zdjęcia",
|
||||
"imagesWithCount_many": "{{count}} zdjęcia",
|
||||
"private": "Prywatne tablice",
|
||||
"updateBoardError": "Błąd aktualizacji tablicy",
|
||||
"uncategorized": "Nieskategoryzowane",
|
||||
"selectBoard": "Wybierz tablicę",
|
||||
"downloadBoard": "Pobierz tablice",
|
||||
"loading": "Ładowanie...",
|
||||
"move": "Przenieś",
|
||||
"noMatching": "Brak pasujących tablic"
|
||||
},
|
||||
"accordions": {
|
||||
"compositing": {
|
||||
"title": "Kompozycja",
|
||||
"infillTab": "Inskrypcja",
|
||||
"coherenceTab": "Przebieg Koherencji"
|
||||
},
|
||||
"generation": {
|
||||
"title": "Generowanie"
|
||||
},
|
||||
"image": {
|
||||
"title": "Zdjęcie"
|
||||
},
|
||||
"advanced": {
|
||||
"options": "$t(accordions.advanced.title) Opcje",
|
||||
"title": "Zaawansowane"
|
||||
},
|
||||
"control": {
|
||||
"title": "Kontrola"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -652,7 +652,6 @@
|
||||
"ipAdapterNoModelSelected": "IP адаптер не выбран",
|
||||
"controlAdapterNoModelSelected": "не выбрана модель адаптера контроля",
|
||||
"controlAdapterIncompatibleBaseModel": "несовместимая базовая модель адаптера контроля",
|
||||
"rgNoRegion": "регион не выбран",
|
||||
"rgNoPromptsOrIPAdapters": "нет текстовых запросов или IP-адаптеров",
|
||||
"ipAdapterIncompatibleBaseModel": "несовместимая базовая модель IP-адаптера",
|
||||
"ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано",
|
||||
|
||||
@@ -217,7 +217,10 @@
|
||||
"direction": "Phương Hướng",
|
||||
"unknownError": "Lỗi Không Rõ",
|
||||
"selected": "Đã chọn",
|
||||
"tab": "Tab"
|
||||
"tab": "Tab",
|
||||
"loadingModel": "Đang Tải Model",
|
||||
"generating": "Đang Tạo Sinh",
|
||||
"warnings": "Cảnh Báo"
|
||||
},
|
||||
"prompt": {
|
||||
"addPromptTrigger": "Thêm Prompt Trigger",
|
||||
@@ -290,7 +293,8 @@
|
||||
"cancelSucceeded": "Mục Đã Huỷ Bỏ",
|
||||
"completedIn": "Hoàn tất trong",
|
||||
"graphQueued": "Đồ Thị Đã Vào Hàng",
|
||||
"batchQueuedDesc_other": "Thêm {{count}} phiên vào {{direction}} của hàng"
|
||||
"batchQueuedDesc_other": "Thêm {{count}} phiên vào {{direction}} của hàng",
|
||||
"batchSize": "Kích Thước Vùng Hàng Loạt"
|
||||
},
|
||||
"hotkeys": {
|
||||
"canvas": {
|
||||
@@ -733,7 +737,9 @@
|
||||
"textualInversions": "Bộ Đảo Ngược Văn Bản",
|
||||
"loraTriggerPhrases": "Từ Ngữ Kích Hoạt Cho LoRA",
|
||||
"width": "Chiều Rộng",
|
||||
"starterModelsInModelManager": "Model khởi đầu có thể tìm thấy ở Trình Quản Lý Model"
|
||||
"starterModelsInModelManager": "Model khởi đầu có thể tìm thấy ở Trình Quản Lý Model",
|
||||
"clipLEmbed": "CLIP-L Embed",
|
||||
"clipGEmbed": "CLIP-G Embed"
|
||||
},
|
||||
"metadata": {
|
||||
"guidance": "Hướng Dẫn",
|
||||
@@ -905,7 +911,7 @@
|
||||
"unknownNode": "Node Không Rõ",
|
||||
"unknownNodeType": "Loại Node Không Rõ",
|
||||
"unknownTemplate": "Mẫu Trình Bày Không Rõ",
|
||||
"cannotConnectOutputToOutput": "Không thế kết nối đầu ra với đầu vào",
|
||||
"cannotConnectOutputToOutput": "Không thế kết nối đầu ra với đầu ra",
|
||||
"cannotConnectToSelf": "Không thể kết nối với chính nó",
|
||||
"workflow": "Workflow",
|
||||
"addNodeToolTip": "Thêm Node (Shift+A, Space)",
|
||||
@@ -952,7 +958,9 @@
|
||||
"executionStateInProgress": "Đang Xử Lý",
|
||||
"showLegendNodes": "Hiển Thị Vùng Nhập",
|
||||
"outputFieldTypeParseError": "Không thể phân tích loại dữ liệu đầu ra của {{node}}.{{field}} ({{message}})",
|
||||
"modelAccessError": "Không thể tìm thấy model {{key}}, chuyển về mặc định"
|
||||
"modelAccessError": "Không thể tìm thấy model {{key}}, chuyển về mặc định",
|
||||
"internalDesc": "Trình kích hoạt này được dùng bên trong bởi Invoke. Nó có thể phá hỏng thay đổi trong khi cập nhật ứng dụng và có thể bị xoá bất cứ lúc nào.",
|
||||
"specialDesc": "Trình kích hoạt này có một số xử lý đặc biệt trong ứng dụng. Ví dụ, Node Hàng Loạt được dùng để xếp vào nhiều đồ thị từ một workflow."
|
||||
},
|
||||
"popovers": {
|
||||
"paramCFGRescaleMultiplier": {
|
||||
@@ -1105,7 +1113,9 @@
|
||||
},
|
||||
"controlNetWeight": {
|
||||
"paragraphs": [
|
||||
"Trọng lượng của Control Adapter. Trọng lượng càng cao sẽ dẫn đến tác động càng lớn lên ảnh cuối cùng."
|
||||
"Điều chỉnh mức độ layer ảnh hưởng đến quá trình xử lý tạo sinh.",
|
||||
"• Trọng Lượng Lớn Hơn (.75-2): Gây ra ảnh hưởng lớn hơn lên kết quả cuối cùng.",
|
||||
"• Trọng Lượng Nhỏ Hơn (0-.75): Gây ra ảnh hưởng nhỏ hơn lên kết quả cuối cùng."
|
||||
],
|
||||
"heading": "Trọng Lượng"
|
||||
},
|
||||
@@ -1149,7 +1159,7 @@
|
||||
},
|
||||
"ipAdapterMethod": {
|
||||
"paragraphs": [
|
||||
"Cách thức dùng để áp dụng IP Adapter hiện tại."
|
||||
"Phương thức định nghĩa cách ảnh mẫu sẽ chỉ dẫn quá trình xử lý tạo sinh."
|
||||
],
|
||||
"heading": "Cách Thức"
|
||||
},
|
||||
@@ -1196,8 +1206,9 @@
|
||||
},
|
||||
"controlNetBeginEnd": {
|
||||
"paragraphs": [
|
||||
"Một phần trong quá trình xử lý khử nhiễu mà sẽ được Control Adapter áp dụng.",
|
||||
"Nói chung, Control Adapter áp dụng vào lúc bắt đầu của quá trình hướng dẫn thành phần, và cũng áp dụng vào lúc kết thúc hướng dẫn chi tiết."
|
||||
"Cài đặt này xác định phần xử lý khử nhiễu (trong khi tạo sinh) kết hợp với chỉ dẫn từ layer này.",
|
||||
"• Bước Bắt Đầu (%): Chỉ định lúc bắt đầu áp dụng chỉ dẫn từ layer này trong quá trình tạo sinh.",
|
||||
"• Bước Kết Thúc (%): Chỉ định lúc dừng áp dụng chỉ dẫn của layer này và trở về chỉ dẫn chung từ model và các thiết lập khác."
|
||||
],
|
||||
"heading": "Phần Trăm Tham Số Bước Khi Bắt Đầu/Kết Thúc"
|
||||
},
|
||||
@@ -1401,7 +1412,6 @@
|
||||
"invoke": {
|
||||
"layer": {
|
||||
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, tỉ lệ chiều dài hộp giới hạn là {{height}}",
|
||||
"rgNoRegion": "không có vùng được chọn",
|
||||
"ipAdapterNoModelSelected": "không có IP Adapter được lựa chọn",
|
||||
"ipAdapterNoImageSelected": "không có ảnh IP Adapter được lựa chọn",
|
||||
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, chiều dài hộp giới hạn là {{height}}",
|
||||
@@ -1410,15 +1420,20 @@
|
||||
"rgNoPromptsOrIPAdapters": "không có lệnh chữ hoặc IP Adapter",
|
||||
"controlAdapterIncompatibleBaseModel": "model cơ sở của Control Adapter không tương thích",
|
||||
"ipAdapterIncompatibleBaseModel": "dạng model cơ sở của IP Adapter không tương thích",
|
||||
"controlAdapterNoModelSelected": "không có model Control Adapter được chọn"
|
||||
"controlAdapterNoModelSelected": "không có model Control Adapter được chọn",
|
||||
"emptyLayer": "layer trống",
|
||||
"rgAutoNegativeNotSupported": "trình tự động đảo chiều không được hỗ trợ cho model cơ sở đang dùng",
|
||||
"rgNegativePromptNotSupported": "lệnh tiêu cực không được hỗ trợ cho model cơ sở đang dùng",
|
||||
"unsupportedModel": "layer không được hỗ trợ cho model cơ sở đang dùng",
|
||||
"rgReferenceImagesNotSupported": "ảnh mẫu khu vực không được hỗ trợ cho model cơ sở đang dùng"
|
||||
},
|
||||
"fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), chiều rộng hộp giới hạn là {{width}}",
|
||||
"noModelSelected": "Không có model được lựa chọn",
|
||||
"fluxModelIncompatibleScaledBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), tỉ lệ chiều dài hộp giới hạn là {{height}}",
|
||||
"canvasIsFiltering": "Canvas đang được lọc",
|
||||
"canvasIsRasterizing": "Canvas đang được raster hoá",
|
||||
"canvasIsTransforming": "Canvas đang được biến đổi",
|
||||
"canvasIsCompositing": "Canvas đang được kết hợp",
|
||||
"canvasIsFiltering": "Canvas đang bận (đang lọc)",
|
||||
"canvasIsRasterizing": "Canvas đang bận (đang raster hoá)",
|
||||
"canvasIsTransforming": "Canvas đang bận (đang biến đổi)",
|
||||
"canvasIsCompositing": "Canvas đang bận (đang kết hợp)",
|
||||
"noPrompts": "Không có lệnh được tạo",
|
||||
"noNodesInGraph": "Không có node trong đồ thị",
|
||||
"addingImagesTo": "Thêm ảnh vào",
|
||||
@@ -1430,8 +1445,12 @@
|
||||
"missingNodeTemplate": "Thiếu mẫu trình bày node",
|
||||
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), chiều dài hộp giới hạn là {{height}}",
|
||||
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), tỉ lệ chiều rộng hộp giới hạn là {{width}}",
|
||||
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} thiếu đầu ra",
|
||||
"missingFieldTemplate": "Thiếu vùng mẫu trình bày"
|
||||
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: thiếu đầu vào",
|
||||
"missingFieldTemplate": "Thiếu vùng mẫu trình bày",
|
||||
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} tài nguyên trống",
|
||||
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: quá ít mục, tối thiểu {{minItems}}",
|
||||
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: quá nhiều mục, tối đa {{maxItems}}",
|
||||
"canvasIsSelectingObject": "Canvas đang bận (đang chọn đồ vật)"
|
||||
},
|
||||
"cfgScale": "Thước Đo CFG",
|
||||
"useSeed": "Dùng Tham Số Hạt Giống",
|
||||
@@ -1542,7 +1561,8 @@
|
||||
"resetWebUIDesc2": "Nếu ảnh không được xuất hiện trong thư viện hoặc điều gì đó không ổn đang diễn ra, hãy thử khởi động lại trước khi báo lỗi trên Github.",
|
||||
"displayInProgress": "Hiển Thị Hình Ảnh Đang Xử Lý",
|
||||
"intermediatesClearedFailed": "Có Vấn Đề Khi Dọn Sạch Sản Phẩm Trung Gian",
|
||||
"enableInvisibleWatermark": "Bật Chế Độ Ẩn Watermark"
|
||||
"enableInvisibleWatermark": "Bật Chế Độ Ẩn Watermark",
|
||||
"showDetailedInvocationProgress": "Hiện Dữ Liệu Xử Lý"
|
||||
},
|
||||
"sdxl": {
|
||||
"loading": "Đang Tải...",
|
||||
@@ -1594,7 +1614,7 @@
|
||||
"pullBboxIntoLayerError": "Có Vấn Đề Khi Chuyển Hộp Giới Hạn Thành Layer",
|
||||
"pullBboxIntoReferenceImageOk": "Chuyển Hộp Giới Hạn Thành Ảnh Mẫu",
|
||||
"clearCaches": "Xoá Bộ Nhớ Đệm",
|
||||
"outputOnlyMaskedRegions": "Chỉ Xuất Đầu Ra Ở Vùng Phủ",
|
||||
"outputOnlyMaskedRegions": "Chỉ Xuất Đầu Ra Ở Vùng Tạo Sinh",
|
||||
"addLayer": "Thêm Layer",
|
||||
"regional": "Khu Vực",
|
||||
"regionIsEmpty": "Vùng được chọn trống",
|
||||
@@ -1608,10 +1628,13 @@
|
||||
"moveForward": "Chuyển Lên Đầu",
|
||||
"fitBboxToLayers": "Xếp Vừa Hộp Giới Hạn Vào Layer",
|
||||
"ipAdapterMethod": {
|
||||
"full": "Đầy Đủ",
|
||||
"full": "Phong Cách Và Thành Phần",
|
||||
"style": "Chỉ Lấy Phong Cách",
|
||||
"composition": "Chỉ Lấy Thành Phần",
|
||||
"ipAdapterMethod": "Cách Thức IP Adapter"
|
||||
"ipAdapterMethod": "Cách Thức",
|
||||
"compositionDesc": "Áp dụng cách trình bày và bỏ qua phong cách mẫu.",
|
||||
"fullDesc": "Áp dụng phong cách trực quan (màu, cấu tạo) & thành phần (cách trình bày).",
|
||||
"styleDesc": "Áp dụng phong cách trực quan (màu, cấu tạo) và bỏ qua cách trình bày."
|
||||
},
|
||||
"deletePrompt": "Xoá Lệnh",
|
||||
"rasterLayer": "Layer Dạng Raster",
|
||||
@@ -1899,7 +1922,16 @@
|
||||
"colorPicker": "Chọn Màu"
|
||||
},
|
||||
"mergingLayers": "Đang gộp layer",
|
||||
"controlLayerEmptyState": "<UploadButton>Tải lên ảnh</UploadButton>, kéo thả ảnh từ <GalleryButton>thư viện</GalleryButton> vào layer này, hoặc vẽ trên canvas để bắt đầu."
|
||||
"controlLayerEmptyState": "<UploadButton>Tải lên ảnh</UploadButton>, kéo thả ảnh từ <GalleryButton>thư viện</GalleryButton> vào layer này, hoặc vẽ trên canvas để bắt đầu.",
|
||||
"referenceImageEmptyState": "<UploadButton>Tải lên ảnh</UploadButton> hoặc kéo thả ảnh từ <GalleryButton>thư viện</GalleryButton> vào layer này để bắt đầu.",
|
||||
"useImage": "Dùng Hình Ảnh",
|
||||
"resetCanvasLayers": "Khởi Động Lại Layer Canvas",
|
||||
"asRasterLayer": "Như $t(controlLayers.rasterLayer)",
|
||||
"asRasterLayerResize": "Như $t(controlLayers.rasterLayer) (Thay Đổi Kích Thước)",
|
||||
"asControlLayer": "Như $t(controlLayers.controlLayer)",
|
||||
"asControlLayerResize": "Như $t(controlLayers.controlLayer) (Thay Đổi Kích Thước)",
|
||||
"newSession": "Phiên Làm Việc Mới",
|
||||
"resetGenerationSettings": "Khởi Động Lại Cài Đặt Tạo Sinh"
|
||||
},
|
||||
"stylePresets": {
|
||||
"negativePrompt": "Lệnh Tiêu Cực",
|
||||
@@ -2124,8 +2156,8 @@
|
||||
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
|
||||
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
|
||||
"items": [
|
||||
"<StrongComponent>SD 3.5</StrongComponent>: Hỗ trợ cho Từ ngữ Sang Hình Ảnh trong Workflow với phiên bản SD 3.5 Medium hoặc Large.",
|
||||
"<StrongComponent>Canvas</StrongComponent>: Hợp lý hoá cách xử lý Layer Điều Khiển Được và cải thiện thiết lập điều khiển mặc định."
|
||||
"<StrongComponent>Workflows</StrongComponent>: Chạy một workflow cho nhiều ảnh bằng node <StrongComponent>Ảnh Hàng Loạt</StrongComponent> mới.",
|
||||
"<StrongComponent>FLUX</StrongComponent>: Hỗ trợ cho XLabs IP Adapter v2."
|
||||
]
|
||||
},
|
||||
"upsell": {
|
||||
@@ -2133,5 +2165,67 @@
|
||||
"inviteTeammates": "Thêm Đồng Đội",
|
||||
"shareAccess": "Chia Sẻ Quyền Truy Cập",
|
||||
"professionalUpsell": "Không có sẵn Phiên Bản Chuyên Nghiệp cho Invoke. Bấm vào đây hoặc đến invoke.com/pricing để thêm chi tiết."
|
||||
},
|
||||
"supportVideos": {
|
||||
"supportVideos": "Video Hỗ Trợ",
|
||||
"gettingStarted": "Bắt Đầu Làm Quen",
|
||||
"studioSessionsDesc1": "Xem thử <StudioSessionsPlaylistLink /> để hiểu rõ Invoke hơn.",
|
||||
"studioSessionsDesc2": "Đến <DiscordLink /> để tham gia vào phiên trực tiếp và hỏi câu hỏi. Các phiên được tải lên danh sách phát vào các tuần.",
|
||||
"videos": {
|
||||
"howDoIDoImageToImageTransformation": {
|
||||
"title": "Làm Sao Để Tôi Dùng Trình Biến Đổi Hình Ảnh Sang Hình Ảnh?",
|
||||
"description": "Hướng dẫn cách thực hiện biến đổi ảnh sang ảnh trong Invoke."
|
||||
},
|
||||
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
|
||||
"description": "Giới thiệu về ảnh mẫu và IP adapter toàn vùng.",
|
||||
"title": "Làm Sao Để Tôi Dùng IP Adapter Toàn Vùng Và Ảnh Mẫu?"
|
||||
},
|
||||
"creatingAndComposingOnInvokesControlCanvas": {
|
||||
"description": "Học cách sáng tạo ảnh bằng trình điều khiển canvas của Invoke.",
|
||||
"title": "Sáng Tạo Trong Trình Kiểm Soát Canvas Của Invoke"
|
||||
},
|
||||
"upscaling": {
|
||||
"description": "Cách upscale ảnh bằng bộ công cụ của Invoke để nâng cấp độ phân giải.",
|
||||
"title": "Upscale (Nâng Cấp Chất Lượng Hình Ảnh)"
|
||||
},
|
||||
"howDoIGenerateAndSaveToTheGallery": {
|
||||
"title": "Làm Sao Để Tôi Tạo Sinh Và Lưu Vào Thư Viện?",
|
||||
"description": "Các bước để tạo sinh và lưu ảnh vào thư viện."
|
||||
},
|
||||
"howDoIEditOnTheCanvas": {
|
||||
"description": "Hướng dẫn chỉnh sửa ảnh trực tiếp trên canvas.",
|
||||
"title": "Làm Sao Để Tôi Chỉnh Sửa Trên Canvas?"
|
||||
},
|
||||
"howDoIUseControlNetsAndControlLayers": {
|
||||
"title": "Làm Sao Để Tôi Dùng ControlNet và Layer Điều Khiển Được?",
|
||||
"description": "Học cách áp dụng layer điều khiển được và controlnet vào ảnh của bạn."
|
||||
},
|
||||
"howDoIUseInpaintMasks": {
|
||||
"title": "Làm Sao Để Tôi Dùng Lớp Phủ Inpaint?",
|
||||
"description": "Cách áp dụng lớp phủ inpaint vào chỉnh sửa và thay đổi ảnh."
|
||||
},
|
||||
"howDoIOutpaint": {
|
||||
"title": "Làm Sao Để Tôi Outpaint?",
|
||||
"description": "Hướng dẫn outpaint bên ngoài viền ảnh gốc."
|
||||
},
|
||||
"creatingYourFirstImage": {
|
||||
"description": "Giới thiệu về cách tạo ảnh từ ban đầu bằng công cụ Invoke.",
|
||||
"title": "Tạo Hình Ảnh Đầu Tiên Của Bạn"
|
||||
},
|
||||
"usingControlLayersAndReferenceGuides": {
|
||||
"description": "Học cách chỉ dẫn ảnh được tạo ra bằng layer điều khiển được và ảnh mẫu.",
|
||||
"title": "Dùng Layer Điều Khiển Được và Chỉ Dẫn Mẫu"
|
||||
},
|
||||
"understandingImageToImageAndDenoising": {
|
||||
"title": "Hiểu Rõ Trình Hình Ảnh Sang Hình Ảnh Và Trình Khử Nhiễu",
|
||||
"description": "Tổng quan về trình biến đổi ảnh sang ảnh và trình khử nhiễu trong Invoke."
|
||||
},
|
||||
"exploringAIModelsAndConceptAdapters": {
|
||||
"title": "Khám Phá Model AI Và Khái Niệm Về Adapter",
|
||||
"description": "Đào sâu vào model AI và cách dùng những adapter để điều khiển một cách sáng tạo."
|
||||
}
|
||||
},
|
||||
"controlCanvas": "Điều Khiển Canvas",
|
||||
"watch": "Xem"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -668,7 +668,6 @@
|
||||
"controlAdapterIncompatibleBaseModel": "Control Adapter的基础模型不兼容",
|
||||
"ipAdapterIncompatibleBaseModel": "IP Adapter的基础模型不兼容",
|
||||
"ipAdapterNoImageSelected": "未选择IP Adapter图像",
|
||||
"rgNoRegion": "未选择区域",
|
||||
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}},边界框宽度为 {{width}}",
|
||||
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}},缩放后的边界框高度为 {{height}}",
|
||||
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}},边界框高度为 {{height}}",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useClearQueue } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
|
||||
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
|
||||
import { useInvoke } from 'features/queue/hooks/useInvoke';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
|
||||
@@ -63,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => {
|
||||
justifyContent="flex-start"
|
||||
leftIcon={<PiPlusBold />}
|
||||
onClick={addRegionalGuidance}
|
||||
isDisabled={isFLUX || isSD3}
|
||||
isDisabled={isSD3}
|
||||
>
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
</Button>
|
||||
|
||||
@@ -49,7 +49,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
|
||||
{t('controlLayers.inpaintMask')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isSD3}>
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
|
||||
|
||||
@@ -1,27 +1,28 @@
|
||||
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
|
||||
import type { IconButtonProps } from '@invoke-ai/ui-library';
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleFill } from 'react-icons/pi';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
type Props = Omit<IconButtonProps, 'aria-label'> & {
|
||||
onDelete: () => void;
|
||||
};
|
||||
|
||||
export const RegionalGuidanceDeletePromptButton = memo(({ onDelete }: Props) => {
|
||||
export const RegionalGuidanceDeletePromptButton = memo(({ onDelete, ...rest }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Tooltip label={t('controlLayers.deletePrompt')}>
|
||||
<IconButton
|
||||
variant="link"
|
||||
aria-label={t('controlLayers.deletePrompt')}
|
||||
icon={<PiTrashSimpleFill />}
|
||||
onClick={onDelete}
|
||||
flexGrow={0}
|
||||
size="sm"
|
||||
p={0}
|
||||
colorScheme="error"
|
||||
/>
|
||||
</Tooltip>
|
||||
<IconButton
|
||||
tooltip={t('common.delete')}
|
||||
variant="link"
|
||||
aria-label={t('common.delete')}
|
||||
icon={<PiXBold />}
|
||||
onClick={onDelete}
|
||||
flexGrow={0}
|
||||
size="sm"
|
||||
p={0}
|
||||
colorScheme="error"
|
||||
{...rest}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { Button, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { rgIPAdapterDeleted } from 'features/controlLayers/store/canvasSlice';
|
||||
import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd';
|
||||
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
@@ -10,6 +11,7 @@ import { setRegionalGuidanceReferenceImage } from 'features/imageActions/actions
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
@@ -31,6 +33,9 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
|
||||
const onClickGalleryButton = useCallback(() => {
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
const onDeleteIPAdapter = useCallback(() => {
|
||||
dispatch(rgIPAdapterDeleted({ entityIdentifier, referenceImageId }));
|
||||
}, [dispatch, entityIdentifier, referenceImageId]);
|
||||
|
||||
const dndTargetData = useMemo<SetRegionalGuidanceReferenceImageDndTargetData>(
|
||||
() =>
|
||||
@@ -42,26 +47,44 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={3} position="relative" w="full" p={4}>
|
||||
<Text textAlign="center" color="base.300">
|
||||
<Trans
|
||||
i18nKey="controlLayers.referenceImageEmptyState"
|
||||
components={{
|
||||
UploadButton: (
|
||||
<Button
|
||||
isDisabled={isBusy}
|
||||
size="sm"
|
||||
variant="link"
|
||||
color="base.300"
|
||||
{...uploadApi.getUploadButtonProps()}
|
||||
/>
|
||||
),
|
||||
GalleryButton: (
|
||||
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
}}
|
||||
<Flex flexDir="column" gap={2} position="relative" w="full">
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<Text fontWeight="semibold" color="base.400">
|
||||
{t('controlLayers.referenceImage')}
|
||||
</Text>
|
||||
<Spacer />
|
||||
<IconButton
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
icon={<PiXBold />}
|
||||
tooltip={t('controlLayers.deleteReferenceImage')}
|
||||
aria-label={t('controlLayers.deleteReferenceImage')}
|
||||
onClick={onDeleteIPAdapter}
|
||||
colorScheme="error"
|
||||
/>
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex alignItems="center" gap={2} p={4}>
|
||||
<Text textAlign="center" color="base.300">
|
||||
<Trans
|
||||
i18nKey="controlLayers.referenceImageEmptyState"
|
||||
components={{
|
||||
UploadButton: (
|
||||
<Button
|
||||
isDisabled={isBusy}
|
||||
size="sm"
|
||||
variant="link"
|
||||
color="base.300"
|
||||
{...uploadApi.getUploadButtonProps()}
|
||||
/>
|
||||
),
|
||||
GalleryButton: (
|
||||
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</Text>
|
||||
</Flex>
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
<DndDropTarget
|
||||
dndTarget={setRegionalGuidanceReferenceImageDndTarget}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/
|
||||
import { StagingAreaToolbarImageCountButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarImageCountButton';
|
||||
import { StagingAreaToolbarNextButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarNextButton';
|
||||
import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarPrevButton';
|
||||
import { StagingAreaToolbarSaveAsMenu } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveAsMenu';
|
||||
import { StagingAreaToolbarSaveSelectedToGalleryButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveSelectedToGalleryButton';
|
||||
import { StagingAreaToolbarToggleShowResultsButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarToggleShowResultsButton';
|
||||
import { memo } from 'react';
|
||||
@@ -21,6 +22,7 @@ export const StagingAreaToolbar = memo(() => {
|
||||
<StagingAreaToolbarAcceptButton />
|
||||
<StagingAreaToolbarToggleShowResultsButton />
|
||||
<StagingAreaToolbarSaveSelectedToGalleryButton />
|
||||
<StagingAreaToolbarSaveAsMenu />
|
||||
<StagingAreaToolbarDiscardSelectedButton />
|
||||
<StagingAreaToolbarDiscardAllButton />
|
||||
</ButtonGroup>
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
import { IconButton, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
|
||||
import { selectSelectedImage } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { createNewCanvasEntityFromImage } from 'features/imageActions/actions';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiDotsThreeBold } from 'react-icons/pi';
|
||||
import { imageDTOToFile, uploadImage } from 'services/api/endpoints/images';
|
||||
|
||||
const uploadImageArg = { image_category: 'general', is_intermediate: true, silent: true } as const;
|
||||
|
||||
export const StagingAreaToolbarSaveAsMenu = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const selectedImage = useAppSelector(selectSelectedImage);
|
||||
const store = useAppStore();
|
||||
|
||||
const onClickNewRasterLayerFromImage = useCallback(async () => {
|
||||
if (!selectedImage) {
|
||||
return;
|
||||
}
|
||||
const { dispatch, getState } = store;
|
||||
const file = await imageDTOToFile(selectedImage.imageDTO);
|
||||
const imageDTO = await uploadImage({ file, ...uploadImageArg });
|
||||
createNewCanvasEntityFromImage({
|
||||
imageDTO,
|
||||
type: 'raster_layer',
|
||||
dispatch,
|
||||
getState,
|
||||
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
|
||||
});
|
||||
toast({
|
||||
id: 'SENT_TO_CANVAS',
|
||||
title: t('toast.sentToCanvas'),
|
||||
status: 'success',
|
||||
});
|
||||
}, [selectedImage, store, t]);
|
||||
|
||||
const onClickNewControlLayerFromImage = useCallback(async () => {
|
||||
if (!selectedImage) {
|
||||
return;
|
||||
}
|
||||
const { dispatch, getState } = store;
|
||||
const file = await imageDTOToFile(selectedImage.imageDTO);
|
||||
const imageDTO = await uploadImage({ file, ...uploadImageArg });
|
||||
createNewCanvasEntityFromImage({
|
||||
imageDTO,
|
||||
type: 'control_layer',
|
||||
dispatch,
|
||||
getState,
|
||||
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
|
||||
});
|
||||
toast({
|
||||
id: 'SENT_TO_CANVAS',
|
||||
title: t('toast.sentToCanvas'),
|
||||
status: 'success',
|
||||
});
|
||||
}, [selectedImage, store, t]);
|
||||
|
||||
const onClickNewInpaintMaskFromImage = useCallback(async () => {
|
||||
if (!selectedImage) {
|
||||
return;
|
||||
}
|
||||
const { dispatch, getState } = store;
|
||||
const file = await imageDTOToFile(selectedImage.imageDTO);
|
||||
const imageDTO = await uploadImage({ file, ...uploadImageArg });
|
||||
createNewCanvasEntityFromImage({
|
||||
imageDTO,
|
||||
type: 'inpaint_mask',
|
||||
dispatch,
|
||||
getState,
|
||||
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
|
||||
});
|
||||
toast({
|
||||
id: 'SENT_TO_CANVAS',
|
||||
title: t('toast.sentToCanvas'),
|
||||
status: 'success',
|
||||
});
|
||||
}, [selectedImage, store, t]);
|
||||
|
||||
const onClickNewRegionalGuidanceFromImage = useCallback(async () => {
|
||||
if (!selectedImage) {
|
||||
return;
|
||||
}
|
||||
const { dispatch, getState } = store;
|
||||
const file = await imageDTOToFile(selectedImage.imageDTO);
|
||||
const imageDTO = await uploadImage({ file, ...uploadImageArg });
|
||||
createNewCanvasEntityFromImage({
|
||||
imageDTO,
|
||||
type: 'regional_guidance',
|
||||
dispatch,
|
||||
getState,
|
||||
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
|
||||
});
|
||||
toast({
|
||||
id: 'SENT_TO_CANVAS',
|
||||
title: t('toast.sentToCanvas'),
|
||||
status: 'success',
|
||||
});
|
||||
}, [selectedImage, store, t]);
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
aria-label={t('controlLayers.newLayerFromImage')}
|
||||
tooltip={t('controlLayers.newLayerFromImage')}
|
||||
icon={<PiDotsThreeBold />}
|
||||
colorScheme="invokeBlue"
|
||||
isDisabled={!selectedImage}
|
||||
/>
|
||||
<MenuList>
|
||||
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewInpaintMaskFromImage} isDisabled={!selectedImage}>
|
||||
{t('controlLayers.inpaintMask')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<NewLayerIcon />}
|
||||
onClickCapture={onClickNewRegionalGuidanceFromImage}
|
||||
isDisabled={!selectedImage}
|
||||
>
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewControlLayerFromImage} isDisabled={!selectedImage}>
|
||||
{t('controlLayers.controlLayer')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewRasterLayerFromImage} isDisabled={!selectedImage}>
|
||||
{t('controlLayers.rasterLayer')}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
});
|
||||
|
||||
StagingAreaToolbarSaveAsMenu.displayName = 'StagingAreaToolbarSaveAsMenu';
|
||||
@@ -1,6 +1,4 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { selectSelectedImage } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
@@ -9,14 +7,13 @@ import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFloppyDiskBold } from 'react-icons/pi';
|
||||
import { uploadImage } from 'services/api/endpoints/images';
|
||||
import { imageDTOToFile, uploadImage } from 'services/api/endpoints/images';
|
||||
|
||||
const TOAST_ID = 'SAVE_STAGING_AREA_IMAGE_TO_GALLERY';
|
||||
|
||||
export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
|
||||
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
|
||||
const selectedImage = useAppSelector(selectSelectedImage);
|
||||
const authToken = useStore($authToken);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -28,18 +25,8 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
|
||||
// To save the image to gallery, we will download it and re-upload it. This allows the user to delete the image
|
||||
// the gallery without borking the canvas, which may need this image to exist.
|
||||
const result = await withResultAsync(async () => {
|
||||
// Download the image
|
||||
const requestOpts = authToken
|
||||
? {
|
||||
headers: {
|
||||
Authorization: `Bearer ${authToken}`,
|
||||
},
|
||||
}
|
||||
: {};
|
||||
const res = await fetch(selectedImage.imageDTO.image_url, requestOpts);
|
||||
const blob = await res.blob();
|
||||
// Create a new file with the same name, which we will upload
|
||||
const file = new File([blob], `copy_of_${selectedImage.imageDTO.image_name}`, { type: 'image/png' });
|
||||
const file = await imageDTOToFile(selectedImage.imageDTO);
|
||||
|
||||
await uploadImage({
|
||||
file,
|
||||
@@ -66,7 +53,7 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
|
||||
status: 'error',
|
||||
});
|
||||
}
|
||||
}, [autoAddBoardId, selectedImage, t, authToken]);
|
||||
}, [autoAddBoardId, selectedImage, t]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
|
||||
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
|
||||
import { CanvasEntityHeaderWarnings } from 'features/controlLayers/components/common/CanvasEntityHeaderWarnings';
|
||||
import { CanvasEntityIsBookmarkedForQuickSwitchToggle } from 'features/controlLayers/components/common/CanvasEntityIsBookmarkedForQuickSwitchToggle';
|
||||
import { CanvasEntityIsLockedToggle } from 'features/controlLayers/components/common/CanvasEntityIsLockedToggle';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
@@ -11,6 +12,7 @@ export const CanvasEntityHeaderCommonActions = memo(() => {
|
||||
|
||||
return (
|
||||
<Flex alignSelf="stretch">
|
||||
<CanvasEntityHeaderWarnings />
|
||||
<CanvasEntityIsBookmarkedForQuickSwitchToggle />
|
||||
{entityIdentifier.type !== 'reference_image' && <CanvasEntityIsLockedToggle />}
|
||||
<CanvasEntityEnabledToggle />
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
import { Flex, IconButton, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useEntityIsEnabled } from 'features/controlLayers/hooks/useEntityIsEnabled';
|
||||
import { selectModel } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
getControlLayerWarnings,
|
||||
getGlobalReferenceImageWarnings,
|
||||
getInpaintMaskWarnings,
|
||||
getRasterLayerWarnings,
|
||||
getRegionalGuidanceWarnings,
|
||||
} from 'features/controlLayers/store/validators';
|
||||
import type { TFunction } from 'i18next';
|
||||
import { upperFirst } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiWarningBold } from 'react-icons/pi';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunction) => {
|
||||
return createSelector(selectCanvasSlice, selectModel, (canvas, model) => {
|
||||
// This component is used within a <CanvasEntityStateGate /> so we can safely assume that the entity exists.
|
||||
// Should never throw.
|
||||
const entity = selectEntityOrThrow(canvas, entityIdentifier, 'CanvasEntityHeaderWarnings');
|
||||
|
||||
let warnings: string[] = [];
|
||||
|
||||
const entityType = entity.type;
|
||||
|
||||
if (entityType === 'control_layer') {
|
||||
warnings = getControlLayerWarnings(entity, model);
|
||||
} else if (entityType === 'regional_guidance') {
|
||||
warnings = getRegionalGuidanceWarnings(entity, model);
|
||||
} else if (entityType === 'inpaint_mask') {
|
||||
warnings = getInpaintMaskWarnings(entity, model);
|
||||
} else if (entityType === 'raster_layer') {
|
||||
warnings = getRasterLayerWarnings(entity, model);
|
||||
} else if (entityType === 'reference_image') {
|
||||
warnings = getGlobalReferenceImageWarnings(entity, model);
|
||||
} else {
|
||||
assert<Equals<typeof entityType, never>>(false, 'Unexpected entity type');
|
||||
}
|
||||
|
||||
// Return a stable reference if there are no warnings
|
||||
if (warnings.length === 0) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return warnings.map((w) => t(w)).map(upperFirst);
|
||||
});
|
||||
};
|
||||
|
||||
export const CanvasEntityHeaderWarnings = memo(() => {
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
const { t } = useTranslation();
|
||||
const isEnabled = useEntityIsEnabled(entityIdentifier);
|
||||
const selectWarnings = useMemo(() => buildSelectWarnings(entityIdentifier, t), [entityIdentifier, t]);
|
||||
const warnings = useAppSelector(selectWarnings);
|
||||
|
||||
if (warnings.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
// Using IconButton here bc it matches the styling of the actual buttons in the header without any fanagling, but
|
||||
// it's not a button
|
||||
<IconButton
|
||||
as="span"
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
aria-label="warnings"
|
||||
tooltip={<TooltipContent warnings={warnings} />}
|
||||
icon={<PiWarningBold />}
|
||||
colorScheme="warning"
|
||||
isDisabled={!isEnabled}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityHeaderWarnings.displayName = 'CanvasEntityHeaderWarnings';
|
||||
|
||||
const TooltipContent = memo((props: { warnings: string[] }) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
<Text>{t('controlLayers.warnings.problemsFound')}:</Text>
|
||||
<UnorderedList>
|
||||
{props.warnings.map((warning, index) => (
|
||||
<ListItem key={index}>{warning}</ListItem>
|
||||
))}
|
||||
</UnorderedList>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
TooltipContent.displayName = 'TooltipContent';
|
||||
@@ -29,7 +29,13 @@ import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
/** @knipignore */
|
||||
/**
|
||||
* Selects the default control adapter configuration based on the model configurations and the base.
|
||||
*
|
||||
* Be sure to clone the output of this selector before modifying it!
|
||||
*
|
||||
* @knipignore
|
||||
*/
|
||||
export const selectDefaultControlAdapter = createSelector(
|
||||
selectModelConfigsQuery,
|
||||
selectBase,
|
||||
@@ -52,6 +58,11 @@ export const selectDefaultControlAdapter = createSelector(
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Selects the default IP adapter configuration based on the model configurations and the base.
|
||||
*
|
||||
* Be sure to clone the output of this selector before modifying it!
|
||||
*/
|
||||
export const selectDefaultIPAdapter = createSelector(
|
||||
selectModelConfigsQuery,
|
||||
selectBase,
|
||||
@@ -117,7 +128,9 @@ export const useAddRegionalReferenceImage = () => {
|
||||
|
||||
const func = useCallback(() => {
|
||||
const overrides: Partial<CanvasRegionalGuidanceState> = {
|
||||
referenceImages: [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter: defaultIPAdapter }],
|
||||
referenceImages: [
|
||||
{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter: deepClone(defaultIPAdapter) },
|
||||
],
|
||||
};
|
||||
dispatch(rgAdded({ isSelected: true, overrides }));
|
||||
}, [defaultIPAdapter, dispatch]);
|
||||
@@ -129,7 +142,7 @@ export const useAddGlobalReferenceImage = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
|
||||
const func = useCallback(() => {
|
||||
const overrides = { ipAdapter: defaultIPAdapter };
|
||||
const overrides = { ipAdapter: deepClone(defaultIPAdapter) };
|
||||
dispatch(referenceImageAdded({ isSelected: true, overrides }));
|
||||
}, [defaultIPAdapter, dispatch]);
|
||||
|
||||
@@ -140,7 +153,7 @@ export const useAddRegionalGuidanceIPAdapter = (entityIdentifier: CanvasEntityId
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
|
||||
const func = useCallback(() => {
|
||||
dispatch(rgIPAdapterAdded({ entityIdentifier, overrides: { ipAdapter: defaultIPAdapter } }));
|
||||
dispatch(rgIPAdapterAdded({ entityIdentifier, overrides: { ipAdapter: deepClone(defaultIPAdapter) } }));
|
||||
}, [defaultIPAdapter, dispatch, entityIdentifier]);
|
||||
|
||||
return func;
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
import type {
|
||||
CanvasControlLayerState,
|
||||
CanvasInpaintMaskState,
|
||||
CanvasRasterLayerState,
|
||||
CanvasReferenceImageState,
|
||||
CanvasRegionalGuidanceState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
const WARNINGS = {
|
||||
UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel',
|
||||
RG_NO_PROMPTS_OR_IP_ADAPTERS: 'controlLayers.warnings.rgNoPromptsOrIPAdapters',
|
||||
RG_NEGATIVE_PROMPT_NOT_SUPPORTED: 'controlLayers.warnings.rgNegativePromptNotSupported',
|
||||
RG_REFERENCE_IMAGES_NOT_SUPPORTED: 'controlLayers.warnings.rgReferenceImagesNotSupported',
|
||||
RG_AUTO_NEGATIVE_NOT_SUPPORTED: 'controlLayers.warnings.rgAutoNegativeNotSupported',
|
||||
RG_NO_REGION: 'controlLayers.warnings.rgNoRegion',
|
||||
IP_ADAPTER_NO_MODEL_SELECTED: 'controlLayers.warnings.ipAdapterNoModelSelected',
|
||||
IP_ADAPTER_INCOMPATIBLE_BASE_MODEL: 'controlLayers.warnings.ipAdapterIncompatibleBaseModel',
|
||||
IP_ADAPTER_NO_IMAGE_SELECTED: 'controlLayers.warnings.ipAdapterNoImageSelected',
|
||||
CONTROL_ADAPTER_NO_MODEL_SELECTED: 'controlLayers.warnings.controlAdapterNoModelSelected',
|
||||
CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL: 'controlLayers.warnings.controlAdapterIncompatibleBaseModel',
|
||||
CONTROL_ADAPTER_NO_CONTROL: 'controlLayers.warnings.controlAdapterNoControl',
|
||||
} as const;
|
||||
|
||||
type WarningTKey = (typeof WARNINGS)[keyof typeof WARNINGS];
|
||||
|
||||
export const getRegionalGuidanceWarnings = (
|
||||
entity: CanvasRegionalGuidanceState,
|
||||
model: ParameterModel | null
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
if (entity.objects.length === 0) {
|
||||
// Layer is in empty state
|
||||
warnings.push(WARNINGS.RG_NO_REGION);
|
||||
}
|
||||
|
||||
if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) {
|
||||
// Must have at least 1 prompt or IP Adapter
|
||||
warnings.push(WARNINGS.RG_NO_PROMPTS_OR_IP_ADAPTERS);
|
||||
}
|
||||
|
||||
if (model) {
|
||||
if (model.base === 'sd-3' || model.base === 'sd-2') {
|
||||
// Unsupported model architecture
|
||||
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
|
||||
} else if (model.base === 'flux') {
|
||||
// Some features are not supported for flux models
|
||||
if (entity.negativePrompt !== null) {
|
||||
warnings.push(WARNINGS.RG_NEGATIVE_PROMPT_NOT_SUPPORTED);
|
||||
}
|
||||
if (entity.referenceImages.length > 0) {
|
||||
warnings.push(WARNINGS.RG_REFERENCE_IMAGES_NOT_SUPPORTED);
|
||||
}
|
||||
if (entity.autoNegative) {
|
||||
warnings.push(WARNINGS.RG_AUTO_NEGATIVE_NOT_SUPPORTED);
|
||||
}
|
||||
} else {
|
||||
entity.referenceImages.forEach(({ ipAdapter }) => {
|
||||
if (!ipAdapter.model) {
|
||||
// No model selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
|
||||
} else if (ipAdapter.model.base !== model.base) {
|
||||
// Supported model architecture but doesn't match
|
||||
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
|
||||
}
|
||||
|
||||
if (!ipAdapter.image) {
|
||||
// No image selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return warnings;
|
||||
};
|
||||
|
||||
export const getGlobalReferenceImageWarnings = (
|
||||
entity: CanvasReferenceImageState,
|
||||
model: ParameterModel | null
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
if (!entity.ipAdapter.model) {
|
||||
// No model selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
|
||||
} else if (model) {
|
||||
if (model.base === 'sd-3' || model.base === 'sd-2') {
|
||||
// Unsupported model architecture
|
||||
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
|
||||
} else if (entity.ipAdapter.model.base !== model.base) {
|
||||
// Supported model architecture but doesn't match
|
||||
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
|
||||
}
|
||||
}
|
||||
|
||||
if (!entity.ipAdapter.image) {
|
||||
// No image selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
|
||||
}
|
||||
|
||||
return warnings;
|
||||
};
|
||||
|
||||
export const getControlLayerWarnings = (
|
||||
entity: CanvasControlLayerState,
|
||||
model: ParameterModel | null
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
if (entity.objects.length === 0) {
|
||||
// Layer is in empty state
|
||||
warnings.push(WARNINGS.CONTROL_ADAPTER_NO_CONTROL);
|
||||
}
|
||||
|
||||
if (!entity.controlAdapter.model) {
|
||||
// No model selected
|
||||
warnings.push(WARNINGS.CONTROL_ADAPTER_NO_MODEL_SELECTED);
|
||||
} else if (model) {
|
||||
if (model.base === 'sd-3' || model.base === 'sd-2') {
|
||||
// Unsupported model architecture
|
||||
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
|
||||
} else if (entity.controlAdapter.model.base !== model.base) {
|
||||
// Supported model architecture but doesn't match
|
||||
warnings.push(WARNINGS.CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL);
|
||||
}
|
||||
}
|
||||
|
||||
return warnings;
|
||||
};
|
||||
|
||||
export const getRasterLayerWarnings = (
|
||||
_entity: CanvasRasterLayerState,
|
||||
_model: ParameterModel | null
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
// There are no warnings at the moment for raster layers.
|
||||
|
||||
return warnings;
|
||||
};
|
||||
|
||||
export const getInpaintMaskWarnings = (
|
||||
_entity: CanvasInpaintMaskState,
|
||||
_model: ParameterModel | null
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
// There are no warnings at the moment for inpaint masks.
|
||||
|
||||
return warnings;
|
||||
};
|
||||
@@ -77,6 +77,32 @@ export const ImageMenuItemNewLayerFromImageSubMenu = memo(() => {
|
||||
});
|
||||
}, [imageDTO, imageViewer, store, t]);
|
||||
|
||||
const onClickNewRegionalReferenceImageFromImage = useCallback(() => {
|
||||
const { dispatch, getState } = store;
|
||||
createNewCanvasEntityFromImage({ imageDTO, type: 'reference_image', dispatch, getState });
|
||||
dispatch(sentImageToCanvas());
|
||||
dispatch(setActiveTab('canvas'));
|
||||
imageViewer.close();
|
||||
toast({
|
||||
id: 'SENT_TO_CANVAS',
|
||||
title: t('toast.sentToCanvas'),
|
||||
status: 'success',
|
||||
});
|
||||
}, [imageDTO, imageViewer, store, t]);
|
||||
|
||||
const onClickNewGlobalReferenceImageFromImage = useCallback(() => {
|
||||
const { dispatch, getState } = store;
|
||||
createNewCanvasEntityFromImage({ imageDTO, type: 'regional_guidance_with_reference_image', dispatch, getState });
|
||||
dispatch(sentImageToCanvas());
|
||||
dispatch(setActiveTab('canvas'));
|
||||
imageViewer.close();
|
||||
toast({
|
||||
id: 'SENT_TO_CANVAS',
|
||||
title: t('toast.sentToCanvas'),
|
||||
status: 'success',
|
||||
});
|
||||
}, [imageDTO, imageViewer, store, t]);
|
||||
|
||||
return (
|
||||
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiPlusBold />}>
|
||||
<Menu {...subMenu.menuProps}>
|
||||
@@ -104,6 +130,20 @@ export const ImageMenuItemNewLayerFromImageSubMenu = memo(() => {
|
||||
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewRasterLayerFromImage} isDisabled={isBusy}>
|
||||
{t('controlLayers.rasterLayer')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<NewLayerIcon />}
|
||||
onClickCapture={onClickNewRegionalReferenceImageFromImage}
|
||||
isDisabled={isBusy}
|
||||
>
|
||||
{t('controlLayers.referenceImageRegional')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<NewLayerIcon />}
|
||||
onClickCapture={onClickNewGlobalReferenceImageFromImage}
|
||||
isDisabled={isBusy}
|
||||
>
|
||||
{t('controlLayers.referenceImageGlobal')}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
</MenuItem>
|
||||
|
||||
@@ -20,6 +20,7 @@ import { selectBboxModelBase, selectBboxRect } from 'features/controlLayers/stor
|
||||
import type {
|
||||
CanvasControlLayerState,
|
||||
CanvasEntityIdentifier,
|
||||
CanvasEntityState,
|
||||
CanvasEntityType,
|
||||
CanvasInpaintMaskState,
|
||||
CanvasRasterLayerState,
|
||||
@@ -134,14 +135,16 @@ export const createNewCanvasEntityFromImage = (arg: {
|
||||
type: CanvasEntityType | 'regional_guidance_with_reference_image';
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
overrides?: Partial<Pick<CanvasEntityState, 'isEnabled' | 'isLocked' | 'name'>>;
|
||||
}) => {
|
||||
const { type, imageDTO, dispatch, getState } = arg;
|
||||
const { type, imageDTO, dispatch, getState, overrides: _overrides } = arg;
|
||||
const state = getState();
|
||||
const imageObject = imageDTOToImageObject(imageDTO);
|
||||
const { x, y } = selectBboxRect(state);
|
||||
const overrides = {
|
||||
objects: [imageObject],
|
||||
position: { x, y },
|
||||
..._overrides,
|
||||
};
|
||||
switch (type) {
|
||||
case 'raster_layer': {
|
||||
@@ -166,13 +169,13 @@ export const createNewCanvasEntityFromImage = (arg: {
|
||||
break;
|
||||
}
|
||||
case 'reference_image': {
|
||||
const ipAdapter = selectDefaultIPAdapter(getState());
|
||||
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
|
||||
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
|
||||
dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));
|
||||
break;
|
||||
}
|
||||
case 'regional_guidance_with_reference_image': {
|
||||
const ipAdapter = selectDefaultIPAdapter(getState());
|
||||
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
|
||||
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
|
||||
const referenceImages = [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }];
|
||||
dispatch(rgAdded({ overrides: { referenceImages }, isSelected: true }));
|
||||
@@ -288,14 +291,14 @@ export const newCanvasFromImage = (arg: {
|
||||
break;
|
||||
}
|
||||
case 'reference_image': {
|
||||
const ipAdapter = selectDefaultIPAdapter(getState());
|
||||
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
|
||||
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
|
||||
dispatch(canvasReset());
|
||||
dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));
|
||||
break;
|
||||
}
|
||||
case 'regional_guidance_with_reference_image': {
|
||||
const ipAdapter = selectDefaultIPAdapter(getState());
|
||||
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
|
||||
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
|
||||
const referenceImages = [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }];
|
||||
dispatch(canvasReset());
|
||||
|
||||
@@ -1,35 +1,41 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import type {
|
||||
CanvasControlLayerState,
|
||||
ControlNetConfig,
|
||||
Rect,
|
||||
T2IAdapterConfig,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store/types';
|
||||
import { getControlLayerWarnings } from 'features/controlLayers/store/validators';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
|
||||
import type { ImageDTO, Invocation } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
type AddControlNetsArg = {
|
||||
manager: CanvasManager;
|
||||
entities: CanvasControlLayerState[];
|
||||
g: Graph;
|
||||
rect: Rect;
|
||||
collector: Invocation<'collect'>;
|
||||
model: ParameterModel;
|
||||
};
|
||||
|
||||
type AddControlNetsResult = {
|
||||
addedControlNets: number;
|
||||
};
|
||||
|
||||
export const addControlNets = async (
|
||||
manager: CanvasManager,
|
||||
layers: CanvasControlLayerState[],
|
||||
g: Graph,
|
||||
rect: Rect,
|
||||
collector: Invocation<'collect'>,
|
||||
base: BaseModelType
|
||||
): Promise<AddControlNetsResult> => {
|
||||
const validControlLayers = layers
|
||||
.filter((layer) => layer.isEnabled)
|
||||
.filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
|
||||
.filter((layer) => layer.controlAdapter.type === 'controlnet');
|
||||
export const addControlNets = async ({
|
||||
manager,
|
||||
entities,
|
||||
g,
|
||||
rect,
|
||||
collector,
|
||||
model,
|
||||
}: AddControlNetsArg): Promise<AddControlNetsResult> => {
|
||||
const validControlLayers = entities
|
||||
.filter((entity) => entity.isEnabled)
|
||||
.filter((entity) => entity.controlAdapter.type === 'controlnet')
|
||||
.filter((entity) => getControlLayerWarnings(entity, model).length === 0);
|
||||
|
||||
const result: AddControlNetsResult = {
|
||||
addedControlNets: 0,
|
||||
@@ -54,22 +60,31 @@ export const addControlNets = async (
|
||||
return result;
|
||||
};
|
||||
|
||||
type AddT2IAdaptersArg = {
|
||||
manager: CanvasManager;
|
||||
entities: CanvasControlLayerState[];
|
||||
g: Graph;
|
||||
rect: Rect;
|
||||
collector: Invocation<'collect'>;
|
||||
model: ParameterModel;
|
||||
};
|
||||
|
||||
type AddT2IAdaptersResult = {
|
||||
addedT2IAdapters: number;
|
||||
};
|
||||
|
||||
export const addT2IAdapters = async (
|
||||
manager: CanvasManager,
|
||||
layers: CanvasControlLayerState[],
|
||||
g: Graph,
|
||||
rect: Rect,
|
||||
collector: Invocation<'collect'>,
|
||||
base: BaseModelType
|
||||
): Promise<AddT2IAdaptersResult> => {
|
||||
const validControlLayers = layers
|
||||
.filter((layer) => layer.isEnabled)
|
||||
.filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
|
||||
.filter((layer) => layer.controlAdapter.type === 't2i_adapter');
|
||||
export const addT2IAdapters = async ({
|
||||
manager,
|
||||
entities,
|
||||
g,
|
||||
rect,
|
||||
collector,
|
||||
model,
|
||||
}: AddT2IAdaptersArg): Promise<AddT2IAdaptersResult> => {
|
||||
const validControlLayers = entities
|
||||
.filter((entity) => entity.isEnabled)
|
||||
.filter((entity) => entity.controlAdapter.type === 't2i_adapter')
|
||||
.filter((entity) => getControlLayerWarnings(entity, model).length === 0);
|
||||
|
||||
const result: AddT2IAdaptersResult = {
|
||||
addedT2IAdapters: 0,
|
||||
@@ -145,11 +160,3 @@ const addT2IAdapterToGraph = (
|
||||
|
||||
g.addEdge(t2iAdapter, 't2i_adapter', collector, 'item');
|
||||
};
|
||||
|
||||
const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => {
|
||||
// Must be have a model
|
||||
const hasModel = Boolean(controlAdapter.model);
|
||||
// Model must match the current base model
|
||||
const modelMatchesBase = controlAdapter.model?.base === base;
|
||||
return hasModel && modelMatchesBase;
|
||||
};
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import type { CanvasReferenceImageState } from 'features/controlLayers/store/types';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { BaseModelType, Invocation } from 'services/api/types';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type AddIPAdaptersResult = {
|
||||
addedIPAdapters: number;
|
||||
};
|
||||
|
||||
export const addIPAdapters = (
|
||||
ipAdapters: CanvasReferenceImageState[],
|
||||
g: Graph,
|
||||
collector: Invocation<'collect'>,
|
||||
base: BaseModelType
|
||||
): AddIPAdaptersResult => {
|
||||
const validIPAdapters = ipAdapters.filter((entity) => isValidIPAdapter(entity, base));
|
||||
type AddIPAdaptersArg = {
|
||||
entities: CanvasReferenceImageState[];
|
||||
g: Graph;
|
||||
collector: Invocation<'collect'>;
|
||||
model: ParameterModel;
|
||||
};
|
||||
|
||||
export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => {
|
||||
const validIPAdapters = entities.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
|
||||
|
||||
const result: AddIPAdaptersResult = {
|
||||
addedIPAdapters: 0,
|
||||
@@ -76,11 +80,3 @@ const addIPAdapter = (entity: CanvasReferenceImageState, g: Graph, collector: In
|
||||
|
||||
g.addEdge(ipAdapterNode, 'ip_adapter', collector, 'item');
|
||||
};
|
||||
|
||||
const isValidIPAdapter = ({ isEnabled, ipAdapter }: CanvasReferenceImageState, base: BaseModelType): boolean => {
|
||||
// Must be have a model that matches the current base and must have a control image
|
||||
const hasModel = Boolean(ipAdapter.model);
|
||||
const modelMatchesBase = ipAdapter.model?.base === base;
|
||||
const hasImage = Boolean(ipAdapter.image);
|
||||
return isEnabled && hasModel && modelMatchesBase && hasImage;
|
||||
};
|
||||
|
||||
@@ -3,15 +3,12 @@ import { deepClone } from 'common/util/deepClone';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type {
|
||||
CanvasRegionalGuidanceState,
|
||||
IPAdapterConfig,
|
||||
Rect,
|
||||
RegionalGuidanceReferenceImageState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import type { CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types';
|
||||
import { getRegionalGuidanceWarnings } from 'features/controlLayers/store/validators';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { BaseModelType, Invocation } from 'services/api/types';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
@@ -23,19 +20,26 @@ type AddedRegionResult = {
|
||||
addedIPAdapters: number;
|
||||
};
|
||||
|
||||
const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => {
|
||||
const isEnabled = rg.isEnabled;
|
||||
const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt);
|
||||
const hasIPAdapter = rg.referenceImages.filter(({ ipAdapter }) => isValidIPAdapter(ipAdapter, base)).length > 0;
|
||||
return isEnabled && (hasTextPrompt || hasIPAdapter);
|
||||
type AddRegionsArg = {
|
||||
manager: CanvasManager;
|
||||
regions: CanvasRegionalGuidanceState[];
|
||||
g: Graph;
|
||||
bbox: Rect;
|
||||
model: ParameterModel;
|
||||
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
|
||||
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
|
||||
posCondCollect: Invocation<'collect'>;
|
||||
negCondCollect: Invocation<'collect'> | null;
|
||||
ipAdapterCollect: Invocation<'collect'>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Adds regional guidance to the graph
|
||||
* @param manager The canvas manager
|
||||
* @param regions Array of regions to add
|
||||
* @param g The graph to add the layers to
|
||||
* @param base The base model type
|
||||
* @param denoise The main denoise node
|
||||
* @param bbox The bounding box
|
||||
* @param model The main model
|
||||
* @param posCond The positive conditioning node
|
||||
* @param negCond The negative conditioning node
|
||||
* @param posCondCollect The positive conditioning collector
|
||||
@@ -44,22 +48,28 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) =>
|
||||
* @returns A promise that resolves to the regions that were successfully added to the graph
|
||||
*/
|
||||
|
||||
export const addRegions = async (
|
||||
manager: CanvasManager,
|
||||
regions: CanvasRegionalGuidanceState[],
|
||||
g: Graph,
|
||||
bbox: Rect,
|
||||
base: BaseModelType,
|
||||
denoise: Invocation<'denoise_latents'>,
|
||||
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||
posCondCollect: Invocation<'collect'>,
|
||||
negCondCollect: Invocation<'collect'>,
|
||||
ipAdapterCollect: Invocation<'collect'>
|
||||
): Promise<AddedRegionResult[]> => {
|
||||
const isSDXL = base === 'sdxl';
|
||||
export const addRegions = async ({
|
||||
manager,
|
||||
regions,
|
||||
g,
|
||||
bbox,
|
||||
model,
|
||||
posCond,
|
||||
negCond,
|
||||
posCondCollect,
|
||||
negCondCollect,
|
||||
ipAdapterCollect,
|
||||
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
|
||||
const isSDXL = model.base === 'sdxl';
|
||||
const isFLUX = model.base === 'flux';
|
||||
|
||||
const validRegions = regions.filter((rg) => {
|
||||
if (!rg.isEnabled) {
|
||||
return false;
|
||||
}
|
||||
return getRegionalGuidanceWarnings(rg, model).length === 0;
|
||||
});
|
||||
|
||||
const validRegions = regions.filter((rg) => isValidRegion(rg, base));
|
||||
const results: AddedRegionResult[] = [];
|
||||
|
||||
for (const region of validRegions) {
|
||||
@@ -94,20 +104,27 @@ export const addRegions = async (
|
||||
if (region.positivePrompt) {
|
||||
// The main positive conditioning node
|
||||
result.addedPositivePrompt = true;
|
||||
const regionalPosCond = g.addNode(
|
||||
isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_positive_cond'),
|
||||
prompt: region.positivePrompt,
|
||||
style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_positive_cond'),
|
||||
prompt: region.positivePrompt,
|
||||
}
|
||||
);
|
||||
let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
|
||||
if (isSDXL) {
|
||||
regionalPosCond = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_positive_cond'),
|
||||
prompt: region.positivePrompt,
|
||||
style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
||||
});
|
||||
} else if (isFLUX) {
|
||||
regionalPosCond = g.addNode({
|
||||
type: 'flux_text_encoder',
|
||||
id: getPrefixedId('prompt_region_positive_cond'),
|
||||
prompt: region.positivePrompt,
|
||||
});
|
||||
} else {
|
||||
regionalPosCond = g.addNode({
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_positive_cond'),
|
||||
prompt: region.positivePrompt,
|
||||
});
|
||||
}
|
||||
// Connect the mask to the conditioning
|
||||
g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask');
|
||||
// Connect the conditioning to the collector
|
||||
@@ -115,38 +132,55 @@ export const addRegions = async (
|
||||
// Copy the connections to the "global" positive conditioning node to the regional cond
|
||||
if (posCond.type === 'compel') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else if (posCond.type === 'sdxl_compel_prompt') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else if (posCond.type === 'flux_text_encoder') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
assert(false, 'Unsupported positive conditioning node type.');
|
||||
}
|
||||
}
|
||||
|
||||
if (region.negativePrompt) {
|
||||
result.addedNegativePrompt = true;
|
||||
assert(negCond, 'Negative conditioning node is required if there is a negative prompt');
|
||||
assert(negCondCollect, 'Negative conditioning collector is required if there is a negative prompt');
|
||||
|
||||
// The main negative conditioning node
|
||||
const regionalNegCond = g.addNode(
|
||||
isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_negative_cond'),
|
||||
prompt: region.negativePrompt,
|
||||
style: region.negativePrompt,
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_negative_cond'),
|
||||
prompt: region.negativePrompt,
|
||||
}
|
||||
);
|
||||
result.addedNegativePrompt = true;
|
||||
let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
|
||||
if (isSDXL) {
|
||||
regionalNegCond = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_negative_cond'),
|
||||
prompt: region.negativePrompt,
|
||||
style: region.negativePrompt,
|
||||
});
|
||||
} else if (isFLUX) {
|
||||
regionalNegCond = g.addNode({
|
||||
type: 'flux_text_encoder',
|
||||
id: getPrefixedId('prompt_region_negative_cond'),
|
||||
prompt: region.negativePrompt,
|
||||
});
|
||||
} else {
|
||||
regionalNegCond = g.addNode({
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_negative_cond'),
|
||||
prompt: region.negativePrompt,
|
||||
});
|
||||
}
|
||||
|
||||
// Connect the mask to the conditioning
|
||||
g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask');
|
||||
// Connect the conditioning to the collector
|
||||
@@ -158,17 +192,27 @@ export const addRegions = async (
|
||||
clone.destination.node_id = regionalNegCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
} else if (negCond.type === 'sdxl_compel_prompt') {
|
||||
for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalNegCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else if (negCond.type === 'flux_text_encoder') {
|
||||
for (const edge of g.getEdgesTo(negCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalNegCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
assert(false, 'Unsupported negative conditioning node type.');
|
||||
}
|
||||
}
|
||||
|
||||
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
|
||||
if (region.autoNegative && region.positivePrompt) {
|
||||
assert(negCondCollect, 'Negative conditioning collector is required if there is an auto-negative setting');
|
||||
|
||||
result.addedAutoNegativePositivePrompt = true;
|
||||
// We re-use the mask image, but invert it when converting to tensor
|
||||
const invertTensorMask = g.addNode({
|
||||
@@ -178,20 +222,27 @@ export const addRegions = async (
|
||||
// Connect the OG mask image to the inverted mask-to-tensor node
|
||||
g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask');
|
||||
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt
|
||||
const regionalPosCondInverted = g.addNode(
|
||||
isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_positive_cond_inverted'),
|
||||
prompt: region.positivePrompt,
|
||||
style: region.positivePrompt,
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_positive_cond_inverted'),
|
||||
prompt: region.positivePrompt,
|
||||
}
|
||||
);
|
||||
let regionalPosCondInverted: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
|
||||
if (isSDXL) {
|
||||
regionalPosCondInverted = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('prompt_region_positive_cond_inverted'),
|
||||
prompt: region.positivePrompt,
|
||||
style: region.positivePrompt,
|
||||
});
|
||||
} else if (isFLUX) {
|
||||
regionalPosCondInverted = g.addNode({
|
||||
type: 'flux_text_encoder',
|
||||
id: getPrefixedId('prompt_region_positive_cond_inverted'),
|
||||
prompt: region.positivePrompt,
|
||||
});
|
||||
} else {
|
||||
regionalPosCondInverted = g.addNode({
|
||||
type: 'compel',
|
||||
id: getPrefixedId('prompt_region_positive_cond_inverted'),
|
||||
prompt: region.positivePrompt,
|
||||
});
|
||||
}
|
||||
// Connect the inverted mask to the conditioning
|
||||
g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask');
|
||||
// Connect the conditioning to the negative collector
|
||||
@@ -203,20 +254,26 @@ export const addRegions = async (
|
||||
clone.destination.node_id = regionalPosCondInverted.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
} else if (posCond.type === 'sdxl_compel_prompt') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCondInverted.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else if (posCond.type === 'flux_text_encoder') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCondInverted.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
assert(false, 'Unsupported positive conditioning node type.');
|
||||
}
|
||||
}
|
||||
|
||||
const validRGIPAdapters: RegionalGuidanceReferenceImageState[] = region.referenceImages.filter(({ ipAdapter }) =>
|
||||
isValidIPAdapter(ipAdapter, base)
|
||||
);
|
||||
for (const { id, ipAdapter } of region.referenceImages) {
|
||||
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
|
||||
|
||||
for (const { id, ipAdapter } of validRGIPAdapters) {
|
||||
result.addedIPAdapters++;
|
||||
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
|
||||
assert(model, 'IP Adapter model is required');
|
||||
@@ -248,11 +305,3 @@ export const addRegions = async (
|
||||
|
||||
return results;
|
||||
};
|
||||
|
||||
const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => {
|
||||
// Must be have a model that matches the current base and must have a control image
|
||||
const hasModel = Boolean(ipAdapter.model);
|
||||
const modelMatchesBase = ipAdapter.model?.base === base;
|
||||
const hasImage = Boolean(ipAdapter.image);
|
||||
return hasModel && modelMatchesBase && hasImage;
|
||||
};
|
||||
|
||||
@@ -11,6 +11,7 @@ import { addImageToImage } from 'features/nodes/util/graph/generation/addImageTo
|
||||
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
|
||||
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
|
||||
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
|
||||
import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
|
||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
@@ -79,7 +80,10 @@ export const buildFLUXGraph = async (
|
||||
id: getPrefixedId('flux_text_encoder'),
|
||||
prompt: positivePrompt,
|
||||
});
|
||||
|
||||
const posCondCollect = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('pos_cond_collect'),
|
||||
});
|
||||
const denoise = g.addNode({
|
||||
type: 'flux_denoise',
|
||||
id: getPrefixedId('flux_denoise'),
|
||||
@@ -104,13 +108,12 @@ export const buildFLUXGraph = async (
|
||||
g.addEdge(modelLoader, 'clip', posCond, 'clip');
|
||||
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
|
||||
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
|
||||
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
|
||||
g.addEdge(posCondCollect, 'collection', denoise, 'positive_text_conditioning');
|
||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||
|
||||
addFLUXLoRAs(state, g, denoise, modelLoader, posCond);
|
||||
|
||||
g.addEdge(posCond, 'conditioning', denoise, 'positive_text_conditioning');
|
||||
|
||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
assert(modelConfig.base === 'flux');
|
||||
|
||||
@@ -196,31 +199,50 @@ export const buildFLUXGraph = async (
|
||||
type: 'collect',
|
||||
id: getPrefixedId('control_net_collector'),
|
||||
});
|
||||
const controlNetResult = await addControlNets(
|
||||
const controlNetResult = await addControlNets({
|
||||
manager,
|
||||
canvas.controlLayers.entities,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
controlNetCollector,
|
||||
modelConfig.base
|
||||
);
|
||||
rect: canvas.bbox.rect,
|
||||
collector: controlNetCollector,
|
||||
model: modelConfig,
|
||||
});
|
||||
if (controlNetResult.addedControlNets > 0) {
|
||||
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
|
||||
} else {
|
||||
g.deleteNode(controlNetCollector.id);
|
||||
}
|
||||
|
||||
const ipAdapterCollector = g.addNode({
|
||||
const ipAdapterCollect = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('ip_adapter_collector'),
|
||||
});
|
||||
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
|
||||
const ipAdapterResult = addIPAdapters({
|
||||
entities: canvas.referenceImages.entities,
|
||||
g,
|
||||
collector: ipAdapterCollect,
|
||||
model: modelConfig,
|
||||
});
|
||||
|
||||
const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
|
||||
const regionsResult = await addRegions({
|
||||
manager,
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
bbox: canvas.bbox.rect,
|
||||
model: modelConfig,
|
||||
posCond,
|
||||
negCond: null,
|
||||
posCondCollect,
|
||||
negCondCollect: null,
|
||||
ipAdapterCollect,
|
||||
});
|
||||
|
||||
const totalIPAdaptersAdded =
|
||||
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
|
||||
if (totalIPAdaptersAdded > 0) {
|
||||
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
|
||||
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
|
||||
} else {
|
||||
g.deleteNode(ipAdapterCollector.id);
|
||||
g.deleteNode(ipAdapterCollect.id);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
||||
@@ -227,14 +227,14 @@ export const buildSD1Graph = async (
|
||||
type: 'collect',
|
||||
id: getPrefixedId('control_net_collector'),
|
||||
});
|
||||
const controlNetResult = await addControlNets(
|
||||
const controlNetResult = await addControlNets({
|
||||
manager,
|
||||
canvas.controlLayers.entities,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
controlNetCollector,
|
||||
modelConfig.base
|
||||
);
|
||||
rect: canvas.bbox.rect,
|
||||
collector: controlNetCollector,
|
||||
model: modelConfig,
|
||||
});
|
||||
if (controlNetResult.addedControlNets > 0) {
|
||||
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
|
||||
} else {
|
||||
@@ -245,46 +245,50 @@ export const buildSD1Graph = async (
|
||||
type: 'collect',
|
||||
id: getPrefixedId('t2i_adapter_collector'),
|
||||
});
|
||||
const t2iAdapterResult = await addT2IAdapters(
|
||||
const t2iAdapterResult = await addT2IAdapters({
|
||||
manager,
|
||||
canvas.controlLayers.entities,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
t2iAdapterCollector,
|
||||
modelConfig.base
|
||||
);
|
||||
rect: canvas.bbox.rect,
|
||||
collector: t2iAdapterCollector,
|
||||
model: modelConfig,
|
||||
});
|
||||
if (t2iAdapterResult.addedT2IAdapters > 0) {
|
||||
g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
|
||||
} else {
|
||||
g.deleteNode(t2iAdapterCollector.id);
|
||||
}
|
||||
|
||||
const ipAdapterCollector = g.addNode({
|
||||
const ipAdapterCollect = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('ip_adapter_collector'),
|
||||
});
|
||||
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
|
||||
|
||||
const regionsResult = await addRegions(
|
||||
manager,
|
||||
canvas.regionalGuidance.entities,
|
||||
const ipAdapterResult = addIPAdapters({
|
||||
entities: canvas.referenceImages.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
modelConfig.base,
|
||||
denoise,
|
||||
collector: ipAdapterCollect,
|
||||
model: modelConfig,
|
||||
});
|
||||
|
||||
const regionsResult = await addRegions({
|
||||
manager,
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
bbox: canvas.bbox.rect,
|
||||
model: modelConfig,
|
||||
posCond,
|
||||
negCond,
|
||||
posCondCollect,
|
||||
negCondCollect,
|
||||
ipAdapterCollector
|
||||
);
|
||||
ipAdapterCollect,
|
||||
});
|
||||
|
||||
const totalIPAdaptersAdded =
|
||||
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
|
||||
if (totalIPAdaptersAdded > 0) {
|
||||
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
|
||||
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
|
||||
} else {
|
||||
g.deleteNode(ipAdapterCollector.id);
|
||||
g.deleteNode(ipAdapterCollect.id);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
||||
@@ -232,14 +232,14 @@ export const buildSDXLGraph = async (
|
||||
type: 'collect',
|
||||
id: getPrefixedId('control_net_collector'),
|
||||
});
|
||||
const controlNetResult = await addControlNets(
|
||||
const controlNetResult = await addControlNets({
|
||||
manager,
|
||||
canvas.controlLayers.entities,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
controlNetCollector,
|
||||
modelConfig.base
|
||||
);
|
||||
rect: canvas.bbox.rect,
|
||||
collector: controlNetCollector,
|
||||
model: modelConfig,
|
||||
});
|
||||
if (controlNetResult.addedControlNets > 0) {
|
||||
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
|
||||
} else {
|
||||
@@ -250,46 +250,50 @@ export const buildSDXLGraph = async (
|
||||
type: 'collect',
|
||||
id: getPrefixedId('t2i_adapter_collector'),
|
||||
});
|
||||
const t2iAdapterResult = await addT2IAdapters(
|
||||
const t2iAdapterResult = await addT2IAdapters({
|
||||
manager,
|
||||
canvas.controlLayers.entities,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
t2iAdapterCollector,
|
||||
modelConfig.base
|
||||
);
|
||||
rect: canvas.bbox.rect,
|
||||
collector: t2iAdapterCollector,
|
||||
model: modelConfig,
|
||||
});
|
||||
if (t2iAdapterResult.addedT2IAdapters > 0) {
|
||||
g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
|
||||
} else {
|
||||
g.deleteNode(t2iAdapterCollector.id);
|
||||
}
|
||||
|
||||
const ipAdapterCollector = g.addNode({
|
||||
const ipAdapterCollect = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('ip_adapter_collector'),
|
||||
});
|
||||
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
|
||||
|
||||
const regionsResult = await addRegions(
|
||||
manager,
|
||||
canvas.regionalGuidance.entities,
|
||||
const ipAdapterResult = addIPAdapters({
|
||||
entities: canvas.referenceImages.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
modelConfig.base,
|
||||
denoise,
|
||||
collector: ipAdapterCollect,
|
||||
model: modelConfig,
|
||||
});
|
||||
|
||||
const regionsResult = await addRegions({
|
||||
manager,
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
bbox: canvas.bbox.rect,
|
||||
model: modelConfig,
|
||||
posCond,
|
||||
negCond,
|
||||
posCondCollect,
|
||||
negCondCollect,
|
||||
ipAdapterCollector
|
||||
);
|
||||
ipAdapterCollect,
|
||||
});
|
||||
|
||||
const totalIPAdaptersAdded =
|
||||
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
|
||||
if (totalIPAdaptersAdded > 0) {
|
||||
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
|
||||
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
|
||||
} else {
|
||||
g.deleteNode(ipAdapterCollector.id);
|
||||
g.deleteNode(ipAdapterCollect.id);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
||||
@@ -4,13 +4,13 @@ import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleFill } from 'react-icons/pi';
|
||||
|
||||
import { useClearQueue } from './ClearQueueConfirmationAlertDialog';
|
||||
import { useClearQueueDialog } from './ClearQueueConfirmationAlertDialog';
|
||||
|
||||
type Props = ButtonProps;
|
||||
|
||||
const ClearQueueButton = (props: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const clearQueue = useClearQueue();
|
||||
const clearQueue = useClearQueueDialog();
|
||||
|
||||
return (
|
||||
<>
|
||||
|
||||
@@ -1,51 +1,15 @@
|
||||
import { ConfirmationAlertDialog, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { buildUseBoolean } from 'common/hooks/useBoolean';
|
||||
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useClearQueueMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
const [useClearQueueConfirmationAlertDialog] = buildUseBoolean(false);
|
||||
|
||||
export const useClearQueue = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
export const useClearQueueDialog = () => {
|
||||
const dialog = useClearQueueConfirmationAlertDialog();
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const isConnected = useStore($isConnected);
|
||||
const [trigger, { isLoading }] = useClearQueueMutation({
|
||||
fixedCacheKey: 'clearQueue',
|
||||
});
|
||||
|
||||
const clearQueue = useCallback(async () => {
|
||||
if (!queueStatus?.queue.total) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await trigger().unwrap();
|
||||
toast({
|
||||
id: 'QUEUE_CLEAR_SUCCEEDED',
|
||||
title: t('queue.clearSucceeded'),
|
||||
status: 'success',
|
||||
});
|
||||
dispatch(listCursorChanged(undefined));
|
||||
dispatch(listPriorityChanged(undefined));
|
||||
} catch {
|
||||
toast({
|
||||
id: 'QUEUE_CLEAR_FAILED',
|
||||
title: t('queue.clearFailed'),
|
||||
status: 'error',
|
||||
});
|
||||
}
|
||||
}, [queueStatus?.queue.total, trigger, dispatch, t]);
|
||||
|
||||
const isDisabled = useMemo(() => !isConnected || !queueStatus?.queue.total, [isConnected, queueStatus?.queue.total]);
|
||||
const { clearQueue, isLoading, isDisabled, queueStatus } = useClearQueue();
|
||||
|
||||
return {
|
||||
clearQueue,
|
||||
@@ -61,7 +25,7 @@ export const useClearQueue = () => {
|
||||
export const ClearQueueConfirmationsAlertDialog = memo(() => {
|
||||
useAssertSingleton('ClearQueueConfirmationsAlertDialog');
|
||||
const { t } = useTranslation();
|
||||
const clearQueue = useClearQueue();
|
||||
const clearQueue = useClearQueueDialog();
|
||||
|
||||
return (
|
||||
<ConfirmationAlertDialog
|
||||
|
||||
@@ -4,11 +4,11 @@ import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
import { useClearQueue } from './ClearQueueConfirmationAlertDialog';
|
||||
import { useClearQueueDialog } from './ClearQueueConfirmationAlertDialog';
|
||||
|
||||
export const ClearQueueIconButton = memo((_) => {
|
||||
const { t } = useTranslation();
|
||||
const clearQueue = useClearQueue();
|
||||
const clearQueue = useClearQueueDialog();
|
||||
const cancelCurrentQueueItem = useCancelCurrentQueueItem();
|
||||
|
||||
// Show the single item clear button when shift is pressed
|
||||
|
||||
@@ -147,8 +147,6 @@ const UpscaleTabTooltipContent = memo(({ prepend = false }: { prepend?: boolean
|
||||
<ReasonsList reasons={reasons} />
|
||||
</>
|
||||
)}
|
||||
<StyledDivider />
|
||||
<AddingToText />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
@@ -180,8 +178,6 @@ const WorkflowsTabTooltipContent = memo(({ prepend = false }: { prepend?: boolea
|
||||
<ReasonsList reasons={reasons} />
|
||||
</>
|
||||
)}
|
||||
<StyledDivider />
|
||||
<AddingToText />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { IconButton, Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { SessionMenuItems } from 'common/components/SessionMenuItems';
|
||||
import { useClearQueue } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { useClearQueueDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { QueueCountBadge } from 'features/queue/components/QueueCountBadge';
|
||||
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
|
||||
import { usePauseProcessor } from 'features/queue/hooks/usePauseProcessor';
|
||||
import { useResumeProcessor } from 'features/queue/hooks/useResumeProcessor';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
@@ -17,7 +18,8 @@ export const QueueActionsMenuButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const isPauseEnabled = useFeatureStatus('pauseQueue');
|
||||
const isResumeEnabled = useFeatureStatus('resumeQueue');
|
||||
const clearQueue = useClearQueue();
|
||||
const cancelCurrent = useCancelCurrentQueueItem();
|
||||
const clearQueue = useClearQueueDialog();
|
||||
const {
|
||||
resumeProcessor,
|
||||
isLoading: isLoadingResumeProcessor,
|
||||
@@ -44,9 +46,9 @@ export const QueueActionsMenuButton = memo(() => {
|
||||
<MenuItem
|
||||
isDestructive
|
||||
icon={<PiXBold />}
|
||||
onClick={clearQueue.openDialog}
|
||||
isLoading={clearQueue.isLoading}
|
||||
isDisabled={clearQueue.isDisabled}
|
||||
onClick={cancelCurrent.cancelQueueItem}
|
||||
isLoading={cancelCurrent.isLoading}
|
||||
isDisabled={cancelCurrent.isDisabled}
|
||||
>
|
||||
{t('queue.cancelTooltip')}
|
||||
</MenuItem>
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useClearQueueMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
export const useClearQueue = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const isConnected = useStore($isConnected);
|
||||
const [trigger, { isLoading }] = useClearQueueMutation({
|
||||
fixedCacheKey: 'clearQueue',
|
||||
});
|
||||
|
||||
const clearQueue = useCallback(async () => {
|
||||
if (!queueStatus?.queue.total) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await trigger().unwrap();
|
||||
toast({
|
||||
id: 'QUEUE_CLEAR_SUCCEEDED',
|
||||
title: t('queue.clearSucceeded'),
|
||||
status: 'success',
|
||||
});
|
||||
dispatch(listCursorChanged(undefined));
|
||||
dispatch(listPriorityChanged(undefined));
|
||||
} catch {
|
||||
toast({
|
||||
id: 'QUEUE_CLEAR_FAILED',
|
||||
title: t('queue.clearFailed'),
|
||||
status: 'error',
|
||||
});
|
||||
}
|
||||
}, [queueStatus?.queue.total, trigger, dispatch, t]);
|
||||
|
||||
const isDisabled = useMemo(() => !isConnected || !queueStatus?.queue.total, [isConnected, queueStatus?.queue.total]);
|
||||
|
||||
return {
|
||||
clearQueue,
|
||||
isLoading,
|
||||
queueStatus,
|
||||
isDisabled,
|
||||
};
|
||||
};
|
||||
@@ -4,6 +4,13 @@ import type { ParamsState } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasState } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
getControlLayerWarnings,
|
||||
getGlobalReferenceImageWarnings,
|
||||
getInpaintMaskWarnings,
|
||||
getRasterLayerWarnings,
|
||||
getRegionalGuidanceWarnings,
|
||||
} from 'features/controlLayers/store/validators';
|
||||
import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
|
||||
@@ -278,17 +285,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY['control_layer']);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
const problems: string[] = [];
|
||||
// Must have model
|
||||
if (!controlLayer.controlAdapter.model) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
|
||||
}
|
||||
// Model base must match
|
||||
if (controlLayer.controlAdapter.model?.base !== model?.base) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
|
||||
}
|
||||
const problems = getControlLayerWarnings(controlLayer, model);
|
||||
|
||||
if (problems.length) {
|
||||
const content = upperFirst(problems.join(', '));
|
||||
const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
|
||||
reasons.push({ prefix, content });
|
||||
}
|
||||
});
|
||||
@@ -300,23 +300,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
const problems: string[] = [];
|
||||
|
||||
// Must have model
|
||||
if (!entity.ipAdapter.model) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
|
||||
}
|
||||
// Model base must match
|
||||
if (entity.ipAdapter.model?.base !== model?.base) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
|
||||
}
|
||||
// Must have an image
|
||||
if (!entity.ipAdapter.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
|
||||
}
|
||||
const problems = getGlobalReferenceImageWarnings(entity, model);
|
||||
|
||||
if (problems.length) {
|
||||
const content = upperFirst(problems.join(', '));
|
||||
const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
|
||||
reasons.push({ prefix, content });
|
||||
}
|
||||
});
|
||||
@@ -328,32 +315,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
const problems: string[] = [];
|
||||
// Must have a region
|
||||
if (entity.objects.length === 0) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
|
||||
}
|
||||
// Must have at least 1 prompt or IP Adapter
|
||||
if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
|
||||
}
|
||||
entity.referenceImages.forEach(({ ipAdapter }) => {
|
||||
// Must have model
|
||||
if (!ipAdapter.model) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
|
||||
}
|
||||
// Model base must match
|
||||
if (ipAdapter.model?.base !== model?.base) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
|
||||
}
|
||||
// Must have an image
|
||||
if (!ipAdapter.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
|
||||
}
|
||||
});
|
||||
const problems = getRegionalGuidanceWarnings(entity, model);
|
||||
|
||||
if (problems.length) {
|
||||
const content = upperFirst(problems.join(', '));
|
||||
const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
|
||||
reasons.push({ prefix, content });
|
||||
}
|
||||
});
|
||||
@@ -365,10 +330,25 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
const problems: string[] = [];
|
||||
const problems = getRasterLayerWarnings(entity, model);
|
||||
|
||||
if (problems.length) {
|
||||
const content = upperFirst(problems.join(', '));
|
||||
const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
|
||||
reasons.push({ prefix, content });
|
||||
}
|
||||
});
|
||||
|
||||
canvas.inpaintMasks.entities
|
||||
.filter((entity) => entity.isEnabled)
|
||||
.forEach((entity, i) => {
|
||||
const layerLiteral = i18n.t('controlLayers.layer_one');
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
const problems = getInpaintMaskWarnings(entity, model);
|
||||
|
||||
if (problems.length) {
|
||||
const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
|
||||
reasons.push({ prefix, content });
|
||||
}
|
||||
});
|
||||
|
||||
@@ -31,6 +31,7 @@ const optionsObject: Record<Language, string> = {
|
||||
sv: 'Svenska',
|
||||
tr: 'Türkçe',
|
||||
ua: 'Украї́нська',
|
||||
vi: 'tiếng Việt',
|
||||
zh_CN: '简体中文',
|
||||
zh_Hant: '漢語',
|
||||
};
|
||||
|
||||
@@ -22,6 +22,7 @@ const zLanguage = z.enum([
|
||||
'sv',
|
||||
'tr',
|
||||
'ua',
|
||||
'vi',
|
||||
'zh_CN',
|
||||
'zh_Hant',
|
||||
]);
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ToolChooser } from 'features/controlLayers/components/Tool/ToolChooser';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { useClearQueue } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { useClearQueueDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { InvokeButtonTooltip } from 'features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip';
|
||||
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
|
||||
import { useInvoke } from 'features/queue/hooks/useInvoke';
|
||||
@@ -31,7 +31,7 @@ const FloatingSidePanelButtons = (props: Props) => {
|
||||
const shift = useShiftModifier();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const imageViewer = useImageViewer();
|
||||
const clearQueue = useClearQueue();
|
||||
const clearQueue = useClearQueueDialog();
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const cancelCurrent = useCancelCurrentQueueItem();
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { StartQueryActionCreatorOptions } from '@reduxjs/toolkit/dist/query/core/buildInitiate';
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
@@ -624,3 +625,20 @@ export const uploadImages = async (args: UploadImageArg[]): Promise<ImageDTO[]>
|
||||
);
|
||||
return results.filter((r): r is PromiseFulfilledResult<ImageDTO> => r.status === 'fulfilled').map((r) => r.value);
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert an ImageDTO to a File by downloading the image from the server.
|
||||
* @param imageDTO The image to download and convert to a File
|
||||
*/
|
||||
export const imageDTOToFile = async (imageDTO: ImageDTO): Promise<File> => {
|
||||
const init: RequestInit = {};
|
||||
const authToken = $authToken.get();
|
||||
if (authToken) {
|
||||
init.headers = { Authorization: `Bearer ${authToken}` };
|
||||
}
|
||||
const res = await fetch(imageDTO.image_url, init);
|
||||
const blob = await res.blob();
|
||||
// Create a new file with the same name, which we will upload
|
||||
const file = new File([blob], `copy_of_${imageDTO.image_name}`, { type: 'image/png' });
|
||||
return file;
|
||||
};
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -56,7 +56,7 @@ dependencies = [
|
||||
"torchmetrics",
|
||||
"torchsde",
|
||||
"torchvision",
|
||||
"transformers==4.41.1",
|
||||
"transformers==4.46.3",
|
||||
|
||||
# Core application dependencies, pinned for reproducible builds.
|
||||
"fastapi-events==0.11.1",
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
|
||||
|
||||
parameterize_mps_and_cuda = pytest.mark.parametrize(
|
||||
("device"),
|
||||
[
|
||||
pytest.param(
|
||||
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
|
||||
),
|
||||
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert cached_model.total_bytes() == 100
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_is_in_vram(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert not cached_model.is_in_vram()
|
||||
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.is_in_vram()
|
||||
|
||||
cached_model.full_unload_from_vram()
|
||||
assert not cached_model.is_in_vram()
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_unload(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert cached_model.full_load_to_vram() == 100
|
||||
assert cached_model.is_in_vram()
|
||||
assert all(p.device.type == device for p in cached_model.model.parameters())
|
||||
|
||||
assert cached_model.full_unload_from_vram() == 100
|
||||
assert not cached_model.is_in_vram()
|
||||
assert all(p.device.type == "cpu" for p in cached_model.model.parameters())
|
||||
@@ -0,0 +1,174 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
|
||||
|
||||
parameterize_mps_and_cuda = pytest.mark.parametrize(
|
||||
("device"),
|
||||
[
|
||||
pytest.param(
|
||||
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
|
||||
),
|
||||
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str):
|
||||
if device == "cuda" and not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available.")
|
||||
if device == "mps" and not torch.backends.mps.is_available():
|
||||
pytest.skip("MPS is not available.")
|
||||
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
linear_numel = 10 * 10 + 10
|
||||
assert cached_model.total_bytes() == linear_numel * 4 * 2
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_cur_vram_bytes(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() > 0
|
||||
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
|
||||
assert all(p.device.type == device for p in model.parameters())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_load(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert loaded_bytes == sum(calc_tensor_size(p) for p in model.parameters() if p.device.type == device)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_unload(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() == model_total_bytes
|
||||
|
||||
# Partially unload the model from VRAM.
|
||||
bytes_to_free = int(model_total_bytes * 0.4)
|
||||
freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
|
||||
assert freed_bytes >= bytes_to_free
|
||||
assert freed_bytes < model_total_bytes
|
||||
assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
|
||||
assert freed_bytes == sum(calc_tensor_size(p) for p in model.parameters() if p.device.type == "cpu")
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
loaded_bytes = cached_model.full_load_to_vram()
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes == model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert all(p.device.type == device for p in model.parameters())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_from_partial(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
|
||||
# Full load the rest of the model into VRAM.
|
||||
loaded_bytes_2 = cached_model.full_load_to_vram()
|
||||
assert loaded_bytes_2 > 0
|
||||
assert loaded_bytes_2 < model_total_bytes
|
||||
assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes()
|
||||
assert loaded_bytes + loaded_bytes_2 == model_total_bytes
|
||||
assert all(p.device.type == device for p in model.parameters())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_unload_from_partial(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
|
||||
# Full unload the model from VRAM.
|
||||
unloaded_bytes = cached_model.full_unload_from_vram()
|
||||
assert unloaded_bytes > 0
|
||||
assert unloaded_bytes == loaded_bytes
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_get_cpu_state_dict(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# The CPU state dict can be accessed and has the expected properties.
|
||||
cpu_state_dict = cached_model.get_cpu_state_dict()
|
||||
assert cpu_state_dict is not None
|
||||
assert len(cpu_state_dict) == len(model.state_dict())
|
||||
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
|
||||
|
||||
# The CPU state dict is still available, and still on the CPU.
|
||||
cpu_state_dict = cached_model.get_cpu_state_dict()
|
||||
assert cpu_state_dict is not None
|
||||
assert len(cpu_state_dict) == len(model.state_dict())
|
||||
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
|
||||
13
tests/backend/model_manager/load/model_cache/dummy_module.py
Normal file
13
tests/backend/model_manager/load/model_cache/dummy_module.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import torch
|
||||
|
||||
|
||||
class DummyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(10, 10)
|
||||
self.linear2 = torch.nn.Linear(10, 10)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
|
||||
TorchFunctionAutocastDeviceContext,
|
||||
add_autocast_to_module_forward,
|
||||
)
|
||||
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
|
||||
|
||||
|
||||
def test_torch_function_autocast_device_context():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available.")
|
||||
|
||||
model = DummyModule()
|
||||
# Model parameters should start off on the CPU.
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
with TorchFunctionAutocastDeviceContext(to_device=torch.device("cuda")):
|
||||
x = torch.randn(10, 10, device="cuda")
|
||||
y = model(x)
|
||||
|
||||
# The model output should be on the GPU.
|
||||
assert y.device.type == "cuda"
|
||||
|
||||
# The model parameters should still be on the CPU.
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
|
||||
def test_add_autocast_to_module_forward():
|
||||
model = DummyModule()
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
add_autocast_to_module_forward(model, torch.device("cuda"))
|
||||
# After adding autocast, the model parameters should still be on the CPU.
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
x = torch.randn(10, 10, device="cuda")
|
||||
y = model(x)
|
||||
|
||||
# The model output should be on the GPU.
|
||||
assert y.device.type == "cuda"
|
||||
|
||||
# The model parameters should still be on the CPU.
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
# The autocast context should automatically be disabled after the model forward call completes.
|
||||
# So, attempting to perform an operation with comflicting devices should raise an error.
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = torch.randn(10, device="cuda") * torch.randn(10, device="cpu")
|
||||
@@ -25,7 +25,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelVariantType,
|
||||
VAEDiffusersConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||
HFTestLoraMetadata,
|
||||
|
||||
Reference in New Issue
Block a user