mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-19 18:08:44 -05:00)

Compare commits: ryan/fix-d...ryan/model
74 commits:

987393853c
91c5af1b95
5c67dd507a
2ff928ec17
4327bbe77e
ad1c0d37ef
9708d87946
3ad44f7850
9a482981b2
6b02362b12
8fec4ec91c
693e421970
dc14104bc8
f286a1d1f3
9dc86b2b71
2cab689b79
f8c7adddd0
17da1d92e9
1cc57a4854
3993fae331
1446526d55
62c024e725
1e92bb4e94
db6398fdf6
ebd73a2ac2
8ee95cab00
d1184201a8
5887891654
765ca4e004
159b00a490
3fbf6f2d2a
931fca7cd1
db84a3a5d4
ca8313e805
df849035ee
8d97fe69ca
9044e53a9b
6012b0f912
bb0ed5dc8a
021552fd81
be73dbba92
db9c0cad7c
54b7f9a063
7d488a5352
4d7667f63d
08704ee8ec
5910892c33
46a09d9e90
df0c7d73f3
3905c97e32
0be796a808
7dd33b0f39
484aaf1595
c276b60af9
5d8dd6e26e
5bca68d873
64364e7911
6565cea039
3ebd8d6c07
e970185161
fa5653cdf7
9a7b000995
3a27242838
b54463d294
faee79dc95
e01f66b026
53abdde242
94c088300f
3741a6f5e0
2c23b8414c
20356c0746
bad1149504
fda7aaa7ca
85c616fa34
@@ -1364,7 +1364,6 @@ the in-memory loaded model:

 |                |                 |                  |
 |----------------|-----------------|------------------|
 | `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
 | `model` | AnyModel | The instantiated model (details below) |
-| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |

 ### get_model_by_key(key, [submodel]) -> LoadedModel
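For orientation, a minimal usage sketch of the documented API above (hedged: `key` and `latents` are illustrative variables, and the exact service wiring is an assumption, not part of this diff):

```python
# Sketch only - service/variable names are assumptions for illustration.
loaded_model = model_manager.load.get_model_by_key(key)  # -> LoadedModel
print(loaded_model.config.base)      # `config`: the model's configuration record
with loaded_model as vae:            # the context manager moves the model into VRAM
    image = vae.decode(latents)[0]
# on exit, the model is unlocked and may be offloaded from VRAM again
```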
@@ -38,7 +38,7 @@ This project is a combined effort of dedicated people from across the world. [C

 ## Code of Conduct

-The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
+The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/docs/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.

 By making a contribution to this project, you certify that:
@@ -37,7 +37,7 @@ from invokeai.backend.model_manager.config import (
     ModelFormat,
     ModelType,
 )
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
+from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
 from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
 from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
 from invokeai.backend.model_manager.search import ModelSearch
@@ -110,7 +110,7 @@ async def cancel_by_batch_ids(
 @session_queue_router.put(
     "/{queue_id}/cancel_by_destination",
     operation_id="cancel_by_destination",
-    responses={200: {"model": CancelByBatchIDsResult}},
+    responses={200: {"model": CancelByDestinationResult}},
 )
 async def cancel_by_destination(
     queue_id: str = Path(description="The queue id to perform this operation on"),
@@ -250,6 +250,11 @@ class FluxConditioningField(BaseModel):
     """A conditioning tensor primitive value"""

     conditioning_name: str = Field(description="The name of conditioning tensor")
+    mask: Optional[TensorField] = Field(
+        default=None,
+        description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
+        "included regions should be set to True.",
+    )


 class SD3ConditioningField(BaseModel):
@@ -30,6 +30,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN
 from invokeai.backend.flux.denoise import denoise
 from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux

@@ -42,6 +43,7 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
+from invokeai.backend.flux.text_conditioning import FluxTextConditioning
 from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher

@@ -56,7 +58,7 @@ from invokeai.backend.util.devices import TorchDevice
     title="FLUX Denoise",
     tags=["image", "flux"],
     category="image",
-    version="3.2.1",
+    version="3.2.2",
     classification=Classification.Prototype,
 )
 class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -87,10 +89,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
         title="Transformer",
     )
-    positive_text_conditioning: FluxConditioningField = InputField(
+    positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
     )
-    negative_text_conditioning: FluxConditioningField | None = InputField(
+    negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
         default=None,
         description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
         input=Input.Connection,
@@ -139,36 +141,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         name = context.tensors.save(tensor=latents)
         return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

-    def _load_text_conditioning(
-        self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-        flux_conditioning = flux_conditioning.to(dtype=dtype)
-        t5_embeddings = flux_conditioning.t5_embeds
-        clip_embeddings = flux_conditioning.clip_embeds
-        return t5_embeddings, clip_embeddings
-
     def _run_diffusion(
         self,
         context: InvocationContext,
     ):
         inference_dtype = torch.bfloat16

-        # Load the conditioning data.
-        pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
-            context, self.positive_text_conditioning.conditioning_name, inference_dtype
-        )
-        neg_t5_embeddings: torch.Tensor | None = None
-        neg_clip_embeddings: torch.Tensor | None = None
-        if self.negative_text_conditioning is not None:
-            neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
-                context, self.negative_text_conditioning.conditioning_name, inference_dtype
-            )
-
         # Load the input latents, if provided.
         init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
         if init_latents is not None:
@@ -183,15 +161,45 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             dtype=inference_dtype,
             seed=self.seed,
         )
+        b, _c, latent_h, latent_w = noise.shape
+        packed_h = latent_h // 2
+        packed_w = latent_w // 2
+
+        # Load the conditioning data.
+        pos_text_conditionings = self._load_text_conditioning(
+            context=context,
+            cond_field=self.positive_text_conditioning,
+            packed_height=packed_h,
+            packed_width=packed_w,
+            dtype=inference_dtype,
+            device=TorchDevice.choose_torch_device(),
+        )
+        neg_text_conditionings: list[FluxTextConditioning] | None = None
+        if self.negative_text_conditioning is not None:
+            neg_text_conditionings = self._load_text_conditioning(
+                context=context,
+                cond_field=self.negative_text_conditioning,
+                packed_height=packed_h,
+                packed_width=packed_w,
+                dtype=inference_dtype,
+                device=TorchDevice.choose_torch_device(),
+            )
+        pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
+            pos_text_conditionings, img_seq_len=packed_h * packed_w
+        )
+        neg_regional_prompting_extension = (
+            RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
+            if neg_text_conditionings
+            else None
+        )

         transformer_info = context.models.load(self.transformer.transformer)
         is_schnell = "schnell" in transformer_info.config.config_path

         # Calculate the timestep schedule.
-        image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
         timesteps = get_schedule(
             num_steps=self.num_steps,
-            image_seq_len=image_seq_len,
+            image_seq_len=packed_h * packed_w,
             shift=not is_schnell,
         )
@@ -228,28 +236,17 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

         inpaint_mask = self._prep_inpaint_mask(context, x)

-        b, _c, latent_h, latent_w = x.shape
         img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)

-        pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
-        pos_txt_ids = torch.zeros(
-            pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
-        )
-        neg_txt_ids: torch.Tensor | None = None
-        if neg_t5_embeddings is not None:
-            neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
-            neg_txt_ids = torch.zeros(
-                neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
-            )
-
         # Pack all latent tensors.
         init_latents = pack(init_latents) if init_latents is not None else None
         inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
         noise = pack(noise)
         x = pack(x)

-        # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
-        assert image_seq_len == x.shape[1]
+        # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and
+        # packed_w correctly.
+        assert packed_h * packed_w == x.shape[1]

         # Prepare inpaint extension.
         inpaint_extension: InpaintExtension | None = None
@@ -338,12 +335,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             model=transformer,
             img=x,
             img_ids=img_ids,
-            txt=pos_t5_embeddings,
-            txt_ids=pos_txt_ids,
-            vec=pos_clip_embeddings,
-            neg_txt=neg_t5_embeddings,
-            neg_txt_ids=neg_txt_ids,
-            neg_vec=neg_clip_embeddings,
+            pos_regional_prompting_extension=pos_regional_prompting_extension,
+            neg_regional_prompting_extension=neg_regional_prompting_extension,
             timesteps=timesteps,
             step_callback=self._build_step_callback(context),
             guidance=self.guidance,
@@ -357,6 +350,43 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         x = unpack(x.float(), self.height, self.width)
         return x

+    def _load_text_conditioning(
+        self,
+        context: InvocationContext,
+        cond_field: FluxConditioningField | list[FluxConditioningField],
+        packed_height: int,
+        packed_width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> list[FluxTextConditioning]:
+        """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
+        # Normalize to a list of FluxConditioningFields.
+        cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
+
+        text_conditionings: list[FluxTextConditioning] = []
+        for cond_field in cond_list:
+            # Load the text embeddings.
+            cond_data = context.conditioning.load(cond_field.conditioning_name)
+            assert len(cond_data.conditionings) == 1
+            flux_conditioning = cond_data.conditionings[0]
+            assert isinstance(flux_conditioning, FLUXConditioningInfo)
+            flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
+            t5_embeddings = flux_conditioning.t5_embeds
+            clip_embeddings = flux_conditioning.clip_embeds
+
+            # Load the mask, if provided.
+            mask: Optional[torch.Tensor] = None
+            if cond_field.mask is not None:
+                mask = context.tensors.load(cond_field.mask.tensor_name)
+                mask = mask.to(device=device)
+                mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
+                    mask, packed_height, packed_width, dtype, device
+                )
+
+            text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
+
+        return text_conditionings
+
     @classmethod
     def prep_cfg_scale(
         cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
@@ -1,11 +1,18 @@
 from contextlib import ExitStack
-from typing import Iterator, Literal, Tuple
+from typing import Iterator, Literal, Optional, Tuple

 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    FluxConditioningField,
+    Input,
+    InputField,
+    TensorField,
+    UIComponent,
+)
 from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext

@@ -22,7 +29,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
     title="FLUX Text Encoding",
     tags=["prompt", "conditioning", "flux"],
     category="conditioning",
-    version="1.1.0",
+    version="1.1.1",
     classification=Classification.Prototype,
 )
 class FluxTextEncoderInvocation(BaseInvocation):

@@ -41,9 +48,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
     t5_max_seq_len: Literal[256, 512] = InputField(
         description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
     )
-    prompt: str = InputField(
-        description="Text prompt to encode.",
-        ui_component=UIComponent.Textarea,
-    )
+    prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
+    mask: Optional[TensorField] = InputField(
+        default=None, description="A mask defining the region that this conditioning prompt applies to."
+    )

     @torch.no_grad()

@@ -57,7 +64,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
         )

         conditioning_name = context.conditioning.save(conditioning_data)
-        return FluxConditioningOutput.build(conditioning_name)
+        return FluxConditioningOutput(
+            conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+        )

     def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
         t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
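Taken together, the two node changes above enable regional prompting in a graph: each FLUX Text Encoding node can now carry a mask, and FLUX Denoise accepts a list of conditioning fields. A hedged sketch of how the new field shapes connect (the tensor names are hypothetical; in practice this wiring is done by the graph executor):

```python
from invokeai.app.invocations.fields import FluxConditioningField, TensorField

# Illustrative: two regional prompts, each pointing at a saved conditioning
# tensor and a saved mask tensor (the names are made up for this sketch).
sky = FluxConditioningField(conditioning_name="cond-sky", mask=TensorField(tensor_name="mask-sky"))
ground = FluxConditioningField(conditioning_name="cond-ground", mask=TensorField(tensor_name="mask-ground"))

# FluxDenoiseInvocation.positive_text_conditioning now accepts either a single
# field or a list of fields, as changed in the hunks above.
positive = [sky, ground]
```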
@@ -20,7 +20,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_common import (
     NodeExecutionStatsSummary,
 )
 from invokeai.app.services.invoker import Invoker
-from invokeai.backend.model_manager.load.model_cache import CacheStats
+from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats

 # Size of 1GB in bytes.
 GB = 2**30
@@ -7,7 +7,7 @@ from typing import Callable, Optional

 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
 from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache


 class ModelLoadServiceBase(ABC):

@@ -24,7 +24,7 @@ class ModelLoadServiceBase(ABC):

     @property
     @abstractmethod
-    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+    def ram_cache(self) -> ModelCache:
        """Return the RAM cache used by this loader."""

    @abstractmethod
@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load import (
     ModelLoaderRegistry,
     ModelLoaderRegistryBase,
 )
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

@@ -30,7 +30,7 @@ class ModelLoadService(ModelLoadServiceBase):
     def __init__(
         self,
         app_config: InvokeAIAppConfig,
-        ram_cache: ModelCacheBase[AnyModel],
+        ram_cache: ModelCache,
         registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
     ):
         """Initialize the model load service."""

@@ -45,7 +45,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self._invoker = invoker

     @property
-    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+    def ram_cache(self) -> ModelCache:
         """Return the RAM cache used by this loader."""
         return self._ram_cache

@@ -78,9 +78,8 @@ class ModelLoadService(ModelLoadServiceBase):
         self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
     ) -> LoadedModelWithoutConfig:
         cache_key = str(model_path)
-        ram_cache = self.ram_cache
         try:
-            return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
+            return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
         except IndexError:
             pass

@@ -109,5 +108,5 @@ class ModelLoadService(ModelLoadServiceBase):
         )
         assert loader is not None
         raw_model = loader(model_path)
-        ram_cache.put(key=cache_key, model=raw_model)
-        return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
+        self._ram_cache.put(key=cache_key, model=raw_model)
+        return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
@@ -16,7 +16,8 @@ from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBas
 from invokeai.app.services.model_load.model_load_default import ModelLoadService
 from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
 from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
-from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
+from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
@@ -1,9 +1,10 @@
 import einops
 import torch

+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.math import attention
-from invokeai.backend.flux.modules.layers import DoubleStreamBlock
+from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock


 class CustomDoubleStreamBlockProcessor:

@@ -13,7 +14,12 @@ class CustomDoubleStreamBlockProcessor:

     @staticmethod
     def _double_stream_block_forward(
-        block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
+        block: DoubleStreamBlock,
+        img: torch.Tensor,
+        txt: torch.Tensor,
+        vec: torch.Tensor,
+        pe: torch.Tensor,
+        attn_mask: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
         values.

@@ -40,7 +46,7 @@ class CustomDoubleStreamBlockProcessor:
         k = torch.cat((txt_k, img_k), dim=2)
         v = torch.cat((txt_v, img_v), dim=2)

-        attn = attention(q, k, v, pe=pe)
+        attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

         # calculate the img bloks

@@ -63,11 +69,15 @@ class CustomDoubleStreamBlockProcessor:
         vec: torch.Tensor,
         pe: torch.Tensor,
         ip_adapter_extensions: list[XLabsIPAdapterExtension],
+        regional_prompting_extension: RegionalPromptingExtension,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """A custom implementation of DoubleStreamBlock.forward() with additional features:
         - IP-Adapter support
         """
-        img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
+        attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
+        img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
+            block, img, txt, vec, pe, attn_mask=attn_mask
+        )

         # Apply IP-Adapter conditioning.
         for ip_adapter_extension in ip_adapter_extensions:
@@ -81,3 +91,48 @@ class CustomDoubleStreamBlockProcessor:
         )

         return img, txt
+
+
+class CustomSingleStreamBlockProcessor:
+    """A class containing a custom implementation of SingleStreamBlock.forward() with additional features (masking,
+    etc.)
+    """
+
+    @staticmethod
+    def _single_stream_block_forward(
+        block: SingleStreamBlock,
+        x: torch.Tensor,
+        vec: torch.Tensor,
+        pe: torch.Tensor,
+        attn_mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """This function is a direct copy of SingleStreamBlock.forward()."""
+        mod, _ = block.modulation(vec)
+        x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift
+        qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1)
+
+        q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
+        q, k = block.norm(q, k, v)
+
+        # compute attention
+        attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
+        # compute activation in mlp stream, cat again and run second linear layer
+        output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2))
+        return x + mod.gate * output
+
+    @staticmethod
+    def custom_single_block_forward(
+        timestep_index: int,
+        total_num_timesteps: int,
+        block_index: int,
+        block: SingleStreamBlock,
+        img: torch.Tensor,
+        vec: torch.Tensor,
+        pe: torch.Tensor,
+        regional_prompting_extension: RegionalPromptingExtension,
+    ) -> torch.Tensor:
+        """A custom implementation of SingleStreamBlock.forward() with additional features:
+        - Masking
+        """
+        attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
+        return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)
@@ -7,6 +7,7 @@ from tqdm import tqdm
 from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
 from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.model import Flux

@@ -18,14 +19,8 @@ def denoise(
     # model input
     img: torch.Tensor,
     img_ids: torch.Tensor,
-    # positive text conditioning
-    txt: torch.Tensor,
-    txt_ids: torch.Tensor,
-    vec: torch.Tensor,
-    # negative text conditioning
-    neg_txt: torch.Tensor | None,
-    neg_txt_ids: torch.Tensor | None,
-    neg_vec: torch.Tensor | None,
+    pos_regional_prompting_extension: RegionalPromptingExtension,
+    neg_regional_prompting_extension: RegionalPromptingExtension | None,
     # sampling parameters
     timesteps: list[float],
     step_callback: Callable[[PipelineIntermediateState], None],

@@ -61,9 +56,9 @@ def denoise(
             total_num_timesteps=total_steps,
             img=img,
             img_ids=img_ids,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
+            txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+            txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+            y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
             timesteps=t_vec,
             guidance=guidance_vec,
         )

@@ -78,9 +73,9 @@ def denoise(
         pred = model(
             img=img,
             img_ids=img_ids,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
+            txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+            txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+            y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
             timesteps=t_vec,
             guidance=guidance_vec,
             timestep_index=step_index,

@@ -88,6 +83,7 @@ def denoise(
             controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
             controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
             ip_adapter_extensions=pos_ip_adapter_extensions,
+            regional_prompting_extension=pos_regional_prompting_extension,
         )

         step_cfg_scale = cfg_scale[step_index]

@@ -97,15 +93,15 @@ def denoise(
             # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
             # on systems with sufficient VRAM.

-            if neg_txt is None or neg_txt_ids is None or neg_vec is None:
+            if neg_regional_prompting_extension is None:
                 raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")

             neg_pred = model(
                 img=img,
                 img_ids=img_ids,
-                txt=neg_txt,
-                txt_ids=neg_txt_ids,
-                y=neg_vec,
+                txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                 timesteps=t_vec,
                 guidance=guidance_vec,
                 timestep_index=step_index,

@@ -113,6 +109,7 @@ def denoise(
             controlnet_double_block_residuals=None,
             controlnet_single_block_residuals=None,
             ip_adapter_extensions=neg_ip_adapter_extensions,
+            regional_prompting_extension=neg_regional_prompting_extension,
         )
         pred = neg_pred + step_cfg_scale * (pred - neg_pred)
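For reference, the final combine step retained above, `pred = neg_pred + step_cfg_scale * (pred - neg_pred)`, is standard classifier-free guidance:

$$\hat{v} = v_{\text{neg}} + s\,(v_{\text{pos}} - v_{\text{neg}})$$

where $s$ is the per-step `cfg_scale`. With $s = 1$ this reduces to the positive prediction alone, which is why the code only requires negative conditioning when `cfg_scale` is not 1.0.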
invokeai/backend/flux/extensions/regional_prompting_extension.py (new file, 276 lines)

@@ -0,0 +1,276 @@
from typing import Optional

import torch
import torchvision

from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.mask import to_standard_float_mask


class RegionalPromptingExtension:
    """A class for managing regional prompting with FLUX.

    This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences).
    """

    def __init__(
        self,
        regional_text_conditioning: FluxRegionalTextConditioning,
        restricted_attn_mask: torch.Tensor | None = None,
    ):
        self.regional_text_conditioning = regional_text_conditioning
        self.restricted_attn_mask = restricted_attn_mask

    def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
        order = [self.restricted_attn_mask, None]
        return order[block_index % len(order)]

    def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
        order = [self.restricted_attn_mask, None]
        return order[block_index % len(order)]

    @classmethod
    def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
        """Create a RegionalPromptingExtension from a list of text conditionings.

        Args:
            text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting.
            img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
        """
        regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
        attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
            regional_text_conditioning, img_seq_len
        )
        return cls(
            regional_text_conditioning=regional_text_conditioning,
            restricted_attn_mask=attn_mask_with_restricted_img_self_attn,
        )

    # Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy:
    #
    # @classmethod
    # def _prepare_unrestricted_attn_mask(
    #     cls,
    #     regional_text_conditioning: FluxRegionalTextConditioning,
    #     img_seq_len: int,
    # ) -> torch.Tensor:
    #     """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
    #     - img self-attention is not masked.
    #     - img regions attend to both txt within their own region and to global prompts.
    #     """
    #     device = TorchDevice.choose_torch_device()
    #
    #     # Infer txt_seq_len from the t5_embeddings tensor.
    #     txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
    #
    #     # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
    #     # Concatenation happens in the following order: [txt_seq, img_seq].
    #     # There are 4 portions of the attention mask to consider as we prepare it:
    #     # 1. txt attends to itself
    #     # 2. txt attends to corresponding regional img
    #     # 3. regional img attends to corresponding txt
    #     # 4. regional img attends to itself
    #
    #     # Initialize empty attention mask.
    #     regional_attention_mask = torch.zeros(
    #         (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
    #     )
    #
    #     for image_mask, t5_embedding_range in zip(
    #         regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
    #     ):
    #         # 1. txt attends to itself
    #         regional_attention_mask[
    #             t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
    #         ] = 1.0
    #
    #         # 2. txt attends to corresponding regional img
    #         # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
    #         fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
    #         regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value
    #
    #         # 3. regional img attends to corresponding txt
    #         # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
    #         fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
    #         regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value
    #
    #     # 4. regional img attends to itself
    #     # Allow unrestricted img self attention.
    #     regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
    #
    #     # Convert attention mask to boolean.
    #     regional_attention_mask = regional_attention_mask > 0.5
    #
    #     return regional_attention_mask
    @classmethod
    def _prepare_restricted_attn_mask(
        cls,
        regional_text_conditioning: FluxRegionalTextConditioning,
        img_seq_len: int,
    ) -> torch.Tensor | None:
        """Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
        - img self-attention is only allowed within regions.
        - img regions only attend to txt within their own region, not to global prompts.
        """
        # Identify background region. I.e. the region that is not covered by any region masks.
        background_region_mask: None | torch.Tensor = None
        for image_mask in regional_text_conditioning.image_masks:
            if image_mask is not None:
                if background_region_mask is None:
                    background_region_mask = torch.ones_like(image_mask)
                background_region_mask *= 1 - image_mask

        if background_region_mask is None:
            # There are no region masks, short-circuit and return None.
            # TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this is a rare use
            # case and would make the logic here significantly more complicated.
            return None

        device = TorchDevice.choose_torch_device()

        # Infer txt_seq_len from the t5_embeddings tensor.
        txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]

        # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
        # Concatenation happens in the following order: [txt_seq, img_seq].
        # There are 4 portions of the attention mask to consider as we prepare it:
        # 1. txt attends to itself
        # 2. txt attends to corresponding regional img
        # 3. regional img attends to corresponding txt
        # 4. regional img attends to itself

        # Initialize empty attention mask.
        regional_attention_mask = torch.zeros(
            (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
        )

        for image_mask, t5_embedding_range in zip(
            regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
        ):
            # 1. txt attends to itself
            regional_attention_mask[
                t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
            ] = 1.0

            if image_mask is not None:
                # 2. txt attends to corresponding regional img
                # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
                regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
                    image_mask.view(1, img_seq_len)
                )

                # 3. regional img attends to corresponding txt
                # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
                regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
                    image_mask.view(img_seq_len, 1)
                )

                # 4. regional img attends to itself
                image_mask = image_mask.view(img_seq_len, 1)
                regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
            else:
                # We don't allow attention between non-background image regions and global prompts. This helps to ensure
                # that regions focus on their local prompts. We do, however, allow attention between background regions
                # and global prompts. If we didn't do this, then the background regions would not attend to any txt
                # embeddings, which we found experimentally to cause artifacts.

                # 2. global txt attends to background region
                # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
                regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
                    background_region_mask.view(1, img_seq_len)
                )

                # 3. background region attends to global txt
                # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
                regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
                    background_region_mask.view(img_seq_len, 1)
                )

                # Allow background regions to attend to themselves.
                regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
                regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)

        # Convert attention mask to boolean.
        regional_attention_mask = regional_attention_mask > 0.5

        return regional_attention_mask
    @classmethod
    def _concat_regional_text_conditioning(
        cls,
        text_conditionings: list[FluxTextConditioning],
    ) -> FluxRegionalTextConditioning:
        """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
        concat_t5_embeddings: list[torch.Tensor] = []
        concat_t5_embedding_ranges: list[Range] = []
        image_masks: list[torch.Tensor | None] = []

        # Choose global CLIP embedding.
        # Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use
        # the first prompt's CLIP embedding.
        global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings
        for text_conditioning in text_conditionings:
            if text_conditioning.mask is None:
                global_clip_embedding = text_conditioning.clip_embeddings
                break

        cur_t5_embedding_len = 0
        for text_conditioning in text_conditionings:
            concat_t5_embeddings.append(text_conditioning.t5_embeddings)

            concat_t5_embedding_ranges.append(
                Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
            )

            image_masks.append(text_conditioning.mask)

            cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]

        t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)

        # Initialize the txt_ids tensor.
        pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
        t5_txt_ids = torch.zeros(
            pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
        )

        return FluxRegionalTextConditioning(
            t5_embeddings=t5_embeddings,
            clip_embeddings=global_clip_embedding,
            t5_txt_ids=t5_txt_ids,
            image_masks=image_masks,
            t5_embedding_ranges=concat_t5_embedding_ranges,
        )

    @staticmethod
    def preprocess_regional_prompt_mask(
        mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        """Preprocess a regional prompt mask to match the target height and width.

        If mask is None, returns a mask of all ones with the target height and width.
        If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.

        packed_height and packed_width are the target height and width of the mask in the 'packed' latent space.

        Returns:
            torch.Tensor: The processed mask. shape: (1, 1, packed_height * packed_width).
        """
        if mask is None:
            return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device)

        mask = to_standard_float_mask(mask, out_dtype=dtype)

        tf = torchvision.transforms.Resize(
            (packed_height, packed_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
        )

        # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
        resized_mask = tf(mask)

        # Flatten the height and width dimensions into a single image_seq_len dimension.
        return resized_mask.flatten(start_dim=2)
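To make the four-part mask layout concrete, here is a small self-contained sketch (toy sizes; not part of the diff) that assembles the same "restricted" structure for a single regional prompt:

```python
import torch

# Toy sizes: 2 txt tokens for one regional prompt, img_seq_len = 4, and the
# region covering the first two packed-image positions.
txt_seq_len, img_seq_len = 2, 4
region = torch.tensor([1.0, 1.0, 0.0, 0.0])

m = torch.zeros(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len)
m[:txt_seq_len, :txt_seq_len] = 1.0                 # 1. txt attends to itself
m[:txt_seq_len, txt_seq_len:] = region.view(1, -1)  # 2. txt attends to its region
m[txt_seq_len:, :txt_seq_len] = region.view(-1, 1)  # 3. region attends to its txt
m[txt_seq_len:, txt_seq_len:] += region.view(-1, 1) @ region.view(1, -1)  # 4. region self-attention
attn_mask = m > 0.5  # boolean mask, as in _prepare_restricted_attn_mask
```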
@@ -5,10 +5,10 @@ from einops import rearrange
 from torch import Tensor


-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None) -> Tensor:
     q, k = apply_rope(q, k, pe)

-    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
     x = rearrange(x, "B H L D -> B L (H D)")

     return x
@@ -24,12 +24,12 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     out = torch.einsum("...n,d->...nd", pos, omega)
     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
-    return out.float()
+    return out.to(dtype=pos.dtype, device=pos.device)


 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_ = xq.view(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.view(*xk.shape[:-1], -1, 1, 2)
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+    return xq_out.view(*xq.shape), xk_out.view(*xk.shape)
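For reference, `rope()` builds per-position 2x2 rotation matrices and `apply_rope()` applies them to each even/odd pair of query/key channels; the dtype changes above simply stop forcing this computation through float32:

$$\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos(p\,\omega_i) & -\sin(p\,\omega_i) \\ \sin(p\,\omega_i) & \cos(p\,\omega_i) \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$$

where $p$ is the token position and $\omega_i$ is the per-pair frequency produced by the `torch.einsum` line.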
@@ -5,7 +5,11 @@ from dataclasses import dataclass
 import torch
 from torch import Tensor, nn

-from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
+from invokeai.backend.flux.custom_block_processor import (
+    CustomDoubleStreamBlockProcessor,
+    CustomSingleStreamBlockProcessor,
+)
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.modules.layers import (
     DoubleStreamBlock,

@@ -95,6 +99,7 @@ class Flux(nn.Module):
         controlnet_double_block_residuals: list[Tensor] | None,
         controlnet_single_block_residuals: list[Tensor] | None,
         ip_adapter_extensions: list[XLabsIPAdapterExtension],
+        regional_prompting_extension: RegionalPromptingExtension,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")

@@ -117,7 +122,6 @@ class Flux(nn.Module):
             assert len(controlnet_double_block_residuals) == len(self.double_blocks)
         for block_index, block in enumerate(self.double_blocks):
             assert isinstance(block, DoubleStreamBlock)
-
             img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
                 timestep_index=timestep_index,
                 total_num_timesteps=total_num_timesteps,

@@ -128,6 +132,7 @@ class Flux(nn.Module):
                 vec=vec,
                 pe=pe,
                 ip_adapter_extensions=ip_adapter_extensions,
+                regional_prompting_extension=regional_prompting_extension,
             )

             if controlnet_double_block_residuals is not None:

@@ -140,7 +145,17 @@ class Flux(nn.Module):
             assert len(controlnet_single_block_residuals) == len(self.single_blocks)

         for block_index, block in enumerate(self.single_blocks):
-            img = block(img, vec=vec, pe=pe)
+            assert isinstance(block, SingleStreamBlock)
+            img = CustomSingleStreamBlockProcessor.custom_single_block_forward(
+                timestep_index=timestep_index,
+                total_num_timesteps=total_num_timesteps,
+                block_index=block_index,
+                block=block,
+                img=img,
+                vec=vec,
+                pe=pe,
+                regional_prompting_extension=regional_prompting_extension,
+            )

         if controlnet_single_block_residuals is not None:
             img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]
@@ -66,10 +66,7 @@ class RMSNorm(torch.nn.Module):
         self.scale = nn.Parameter(torch.ones(dim))

     def forward(self, x: Tensor):
-        x_dtype = x.dtype
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms).to(dtype=x_dtype) * self.scale
+        return torch.nn.functional.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)


 class QKNorm(torch.nn.Module):
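This change relies on `torch.nn.functional.rms_norm` (added in PyTorch 2.4) computing the same quantity as the removed manual code. A quick self-contained equivalence check (a sketch, not part of the diff):

```python
import torch

x = torch.randn(2, 8, 16)   # float32 input
scale = torch.ones(16)      # the module's learned scale parameter

# Manual RMSNorm, as in the removed lines.
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
manual = (x * rrms) * scale

# Fused op used by the new code (requires PyTorch >= 2.4).
fused = torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)

assert torch.allclose(manual, fused, atol=1e-5)
```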
invokeai/backend/flux/text_conditioning.py (new file, 36 lines)

@@ -0,0 +1,36 @@
from dataclasses import dataclass

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range


@dataclass
class FluxTextConditioning:
    t5_embeddings: torch.Tensor
    clip_embeddings: torch.Tensor
    # If mask is None, the prompt is a global prompt.
    mask: torch.Tensor | None


@dataclass
class FluxRegionalTextConditioning:
    # Concatenated text embeddings.
    # Shape: (1, concatenated_txt_seq_len, 4096)
    t5_embeddings: torch.Tensor
    # Shape: (1, concatenated_txt_seq_len, 3)
    t5_txt_ids: torch.Tensor

    # Global CLIP embeddings.
    # Shape: (1, 768)
    clip_embeddings: torch.Tensor

    # A binary mask indicating the regions of the image that the prompt should be applied to. If None, the prompt is a
    # global prompt.
    # image_masks[i] is the mask for the ith prompt.
    # image_masks[i] has shape (1, image_seq_len) and dtype torch.bool.
    image_masks: list[torch.Tensor | None]

    # List of ranges that represent the embedding ranges for each mask.
    # t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
    t5_embedding_ranges: list[Range]
invokeai/backend/model_cache_v2/__init__.py (new file, 0 lines)

invokeai/backend/model_cache_v2/cached_model_v2.py (new file, 105 lines)

@@ -0,0 +1,105 @@
import torch

from invokeai.backend.model_cache_v2.torch_module_overrides import CustomLinear, inject_custom_layers_into_module


class CachedModelV2:
    """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.

    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
    MPS memory, etc.
    """

    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
        print("CachedModelV2.__init__")
        self._model = model
        inject_custom_layers_into_module(self._model)
        self._compute_device = compute_device

        # Memoized values.
        self._total_size_cache = None
        self._cur_vram_bytes_cache = None

    @property
    def model(self) -> torch.nn.Module:
        return self._model

    def total_bytes(self) -> int:
        if self._total_size_cache is None:
            self._total_size_cache = sum(p.numel() * p.element_size() for p in self._model.parameters())
        return self._total_size_cache

    def cur_vram_bytes(self) -> int:
        """Return the size (in bytes) of the weights that are currently in VRAM."""
        if self._cur_vram_bytes_cache is None:
            self._cur_vram_bytes_cache = sum(
                p.numel() * p.element_size()
                for p in self._model.parameters()
                if p.device.type == self._compute_device.type
            )
        return self._cur_vram_bytes_cache

    def full_load_to_vram(self):
        """Load all weights into VRAM."""
        raise NotImplementedError("Not implemented")
        self._cur_vram_bytes_cache = self.total_bytes()

    def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
        """Load more weights into VRAM without exceeding vram_bytes_to_load.

        Returns:
            The number of bytes loaded into VRAM.
        """
        vram_bytes_loaded = 0

        def to_vram(m: torch.nn.Module):
            nonlocal vram_bytes_loaded

            if not isinstance(m, CustomLinear):
                # We don't handle offload of this type of module.
                return

            m_device = m.weight.device
            m_bytes = sum(p.numel() * p.element_size() for p in m.parameters())

            # Skip modules that are already on the compute device.
            if m_device.type == self._compute_device.type:
                return

            # Check the size of the parameter.
            if vram_bytes_loaded + m_bytes > vram_bytes_to_load:
                # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
                # worth continuing to search for a smaller parameter that would fit?
                return

            vram_bytes_loaded += m_bytes
            m.to(self._compute_device)

        self._model.apply(to_vram)
        self._cur_vram_bytes_cache = None

        return vram_bytes_loaded

    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
        """Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded."""
        vram_bytes_freed = 0

        def from_vram(m: torch.nn.Module):
            nonlocal vram_bytes_freed

            if vram_bytes_freed >= vram_bytes_to_free:
                return

            m_device = m.weight.device
            m_bytes = sum(p.numel() * p.element_size() for p in m.parameters())
            if m_device.type != self._compute_device.type:
                return

            vram_bytes_freed += m_bytes
            m.to("cpu")

        self._model.apply(from_vram)
        self._cur_vram_bytes_cache = None

        return vram_bytes_freed
invokeai/backend/model_cache_v2/torch_autocast_context.py (new file, 18 lines)

@@ -0,0 +1,18 @@
import torch
from torch.utils._python_dispatch import TorchDispatchMode


def cast_to_device_and_run(func, args, kwargs, to_device: torch.device):
    args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
    kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    return func(*args_on_device, **kwargs_on_device)


class TorchAutocastContext(TorchDispatchMode):
    def __init__(self, to_device: torch.device):
        self._to_device = to_device

    def __torch_dispatch__(self, func, types, args, kwargs):
        # print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
        # print(f"Dispatch Log: {types}")
        return cast_to_device_and_run(func, args, kwargs, self._to_device)
@@ -0,0 +1,16 @@

import torch
from torch.overrides import TorchFunctionMode


def cast_to_device_and_run(func, args, kwargs, to_device: torch.device):
    args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
    kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    return func(*args_on_device, **kwargs_on_device)


class TorchFunctionAutocastContext(TorchFunctionMode):
    def __init__(self, to_device: torch.device):
        self._to_device = to_device

    def __torch_function__(self, func, types, args, kwargs=None):
        return cast_to_device_and_run(func, args, kwargs or {}, self._to_device)
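A minimal usage sketch for the autocast contexts above (hedged: the CUDA device and shapes are illustrative, and nothing in the diff exercises these classes yet):

```python
import torch

# Tensor operands are moved to the target device right before each intercepted
# torch function runs, so CPU-resident tensors can be used in GPU computations.
with TorchFunctionAutocastContext(to_device=torch.device("cuda")):
    x = torch.randn(4, 8)   # created on the CPU as usual (no tensor args to move)
    w = torch.randn(8, 8)
    y = x @ w               # both operands are cast to CUDA before the matmul
assert y.device.type == "cuda"
```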
invokeai/backend/model_cache_v2/torch_module_overrides.py (new file, 26 lines)

@@ -0,0 +1,26 @@
from typing import TypeVar

import torch

T = TypeVar("T", torch.Tensor, None)


def cast_to_device(t: T, to_device: torch.device, non_blocking: bool = True) -> T:
    if t is None:
        return t
    return t.to(to_device, non_blocking=non_blocking)


def inject_custom_layers_into_module(model: torch.nn.Module):
    def inject_custom_layers(module: torch.nn.Module):
        if isinstance(module, torch.nn.Linear):
            module.__class__ = CustomLinear

    model.apply(inject_custom_layers)


class CustomLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)
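A hedged sketch tying the model_cache_v2 pieces together (sizes, the CUDA device, and the budget arithmetic are illustrative assumptions; CachedModelV2 is the class shown earlier in this diff):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))
cached = CachedModelV2(model, compute_device=torch.device("cuda"))

# Move at most ~8 MiB of weights onto the GPU; with these sizes only the first
# Linear (~4 MiB of weights plus bias) fits within the cumulative budget.
loaded = cached.partial_load_to_vram(8 * 2**20)
print(f"moved {loaded} of {cached.total_bytes()} bytes to VRAM")

# Layers left on the CPU still work: CustomLinear casts its weight and bias to
# the input's device at forward time.
y = model(torch.randn(1, 1024, device="cuda"))
```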
@@ -8,7 +8,7 @@ from pathlib import Path

 from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
 from invokeai.backend.model_manager.load.load_default import ModelLoader
-from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase

 # This registers the subclasses that implement loaders of specific model types
@@ -5,7 +5,6 @@ Base class for model loading in InvokeAI.

 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from dataclasses import dataclass
 from logging import Logger
 from pathlib import Path
 from typing import Any, Dict, Generator, Optional, Tuple
@@ -18,19 +17,17 @@ from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     SubModelType,
 )
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache


-@dataclass
 class LoadedModelWithoutConfig:
-    """
-    Context manager object that mediates transfer from RAM<->VRAM.
+    """Context manager object that mediates transfer from RAM<->VRAM.

     This is a context manager object that has two distinct APIs:

     1. Older API (deprecated):
-       Use the LoadedModel object directly as a context manager.
-       It will move the model into VRAM (on CUDA devices), and
+       Use the LoadedModel object directly as a context manager. It will move the model into VRAM (on CUDA devices), and
        return the model in a form suitable for passing to torch.
        Example:
        ```
@@ -40,13 +37,9 @@ class LoadedModelWithoutConfig:
        ```

    2. Newer API (recommended):
-       Call the LoadedModel's `model_on_device()` method in a
-       context. It returns a tuple consisting of a copy of
-       the model's state dict in CPU RAM followed by a copy
-       of the model in VRAM. The state dict is provided to allow
-       LoRAs and other model patchers to return the model to
-       its unpatched state without expensive copy and restore
-       operations.
+       Call the LoadedModel's `model_on_device()` method in a context. It returns a tuple consisting of a copy of the
+       model's state dict in CPU RAM followed by a copy of the model in VRAM. The state dict is provided to allow LoRAs and
+       other model patchers to return the model to its unpatched state without expensive copy and restore operations.

    Example:
    ```
@@ -55,43 +48,42 @@ class LoadedModelWithoutConfig:
        image = vae.decode(latents)[0]
    ```

-    The state_dict should be treated as a read-only object and
-    never modified. Also be aware that some loadable models do
-    not have a state_dict, in which case this value will be None.
+    The state_dict should be treated as a read-only object and never modified. Also be aware that some loadable models
+    do not have a state_dict, in which case this value will be None.
    """

-    _locker: ModelLockerBase
+    def __init__(self, cache_record: CacheRecord, cache: ModelCache):
+        self._cache_record = cache_record
+        self._cache = cache

     def __enter__(self) -> AnyModel:
         """Context entry."""
-        self._locker.lock()
+        self._cache.lock(self._cache_record.key)
         return self.model

     def __exit__(self, *args: Any, **kwargs: Any) -> None:
         """Context exit."""
-        self._locker.unlock()
+        self._cache.unlock(self._cache_record.key)

     @contextmanager
     def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
         """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
-        locked_model = self._locker.lock()
+        self._cache.lock(self._cache_record.key)
         try:
-            state_dict = self._locker.get_state_dict()
-            yield (state_dict, locked_model)
+            yield (self._cache_record.state_dict, self._cache_record.model)
         finally:
-            self._locker.unlock()
+            self._cache.unlock(self._cache_record.key)

     @property
     def model(self) -> AnyModel:
         """Return the model without locking it."""
-        return self._locker.model
+        return self._cache_record.model


-@dataclass
 class LoadedModel(LoadedModelWithoutConfig):
     """Context manager object that mediates transfer from RAM<->VRAM."""

-    config: Optional[AnyModelConfig] = None
+    def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, cache: ModelCache):
+        super().__init__(cache_record=cache_record, cache=cache)
+        self.config = config


 # TODO(MM2):
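The two calling conventions, side by side (a non-authoritative sketch; `loaded_model` and `input_tensor` are placeholders, not names from this diff):

```python
# Older API (deprecated): the LoadedModel itself is the context manager.
with loaded_model as model:
    output = model(input_tensor)

# Newer API (recommended): model_on_device() yields (state_dict, model); the
# state dict is the read-only CPU copy, so patchers should patch only the
# on-device copy of the weights.
with loaded_model.model_on_device() as (state_dict, model):
    output = model(input_tensor)
```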
@@ -110,7 +102,7 @@ class ModelLoaderBase(ABC):
         self,
         app_config: InvokeAIAppConfig,
         logger: Logger,
-        ram_cache: ModelCacheBase[AnyModel],
+        ram_cache: ModelCache,
     ):
         """Initialize the loader."""
         pass
@@ -138,6 +130,6 @@ class ModelLoaderBase(ABC):

     @property
     @abstractmethod
-    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+    def ram_cache(self) -> ModelCache:
         """Return the ram cache associated with this loader."""
         pass

@@ -14,7 +14,8 @@ from invokeai.backend.model_manager import (
 )
 from invokeai.backend.model_manager.config import DiffusersConfigBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import TorchDevice
@@ -28,7 +29,7 @@ class ModelLoader(ModelLoaderBase):
         self,
         app_config: InvokeAIAppConfig,
         logger: Logger,
-        ram_cache: ModelCacheBase[AnyModel],
+        ram_cache: ModelCache,
     ):
         """Initialize the loader."""
         self._app_config = app_config
@@ -54,11 +55,11 @@ class ModelLoader(ModelLoaderBase):
             raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")

         with skip_torch_weight_init():
-            locker = self._load_and_cache(model_config, submodel_type)
-        return LoadedModel(config=model_config, _locker=locker)
+            cache_record = self._load_and_cache(model_config, submodel_type)
+        return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)

     @property
-    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+    def ram_cache(self) -> ModelCache:
         """Return the ram cache associated with this loader."""
         return self._ram_cache

@@ -66,10 +67,10 @@ class ModelLoader(ModelLoaderBase):
         model_base = self._app_config.models_path
         return (model_base / config.path).resolve()

-    def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
+    def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
         stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
         try:
-            return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
+            return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
         except IndexError:
             pass

@@ -78,16 +79,11 @@ class ModelLoader(ModelLoaderBase):
         loaded_model = self._load_model(config, submodel_type)

         self._ram_cache.put(
-            config.key,
-            submodel_type=submodel_type,
+            get_model_cache_key(config.key, submodel_type),
             model=loaded_model,
         )

-        return self._ram_cache.get(
-            key=config.key,
-            submodel_type=submodel_type,
-            stats_name=stats_name,
-        )
+        return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)

     def get_size_fs(
         self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None

@@ -1,6 +0,0 @@
-"""Init file for ModelCache."""
-
-from .model_cache_base import ModelCacheBase, CacheStats  # noqa F401
-from .model_cache_default import ModelCache  # noqa F401
-
-_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]

@@ -0,0 +1,47 @@
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+
+
+@dataclass
+class CacheRecord:
+    """
+    Elements of the cache:
+
+    key: Unique key for each model, same as used in the models database.
+    model: Model in memory.
+    state_dict: A read-only copy of the model's state dict in RAM. It will be
+        used as a template for creating a copy in the VRAM.
+    size: Size of the model
+    loaded: True if the model's state dict is currently in VRAM
+
+    Before a model is executed, the state_dict template is copied into VRAM,
+    and then injected into the model. When the model is finished, the VRAM
+    copy of the state dict is deleted, and the RAM version is reinjected
+    into the model.
+
+    The state_dict should be treated as a read-only attribute. Do not attempt
+    to patch or otherwise modify it. Instead, patch the copy of the state_dict
+    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
+    context manager call `model_on_device()`.
+    """
+
+    key: str
+    model: Any
+    device: torch.device
+    state_dict: Optional[Dict[str, torch.Tensor]]
+    size: int
+    loaded: bool = False
+    _locks: int = 0
+
+    def lock(self) -> None:
+        self._locks += 1
+
+    def unlock(self) -> None:
+        self._locks -= 1
+        assert self._locks >= 0
+
+    @property
+    def is_locked(self) -> bool:
+        return self._locks > 0
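The `_locks` counter gives simple reference-counted pinning. A runnable sketch with placeholder values:

```python
import torch

record = CacheRecord(
    key="abc123:vae", model=object(), device=torch.device("cpu"), state_dict=None, size=0
)
record.lock()
record.lock()
assert record.is_locked      # two holders
record.unlock()
assert record.is_locked      # one holder remains
record.unlock()
assert not record.is_locked  # now eligible for offload or eviction
```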
@@ -0,0 +1,15 @@
+from dataclasses import dataclass, field
+from typing import Dict
+
+
+@dataclass
+class CacheStats(object):
+    """Collect statistics on cache performance."""
+
+    hits: int = 0  # cache hits
+    misses: int = 0  # cache misses
+    high_watermark: int = 0  # amount of cache used
+    in_cache: int = 0  # number of models in cache
+    cleared: int = 0  # number of models cleared to make space
+    cache_size: int = 0  # total size of cache
+    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
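How the stats object is typically wired in, hedged on the `ModelCache.stats` setter shown further down (`cache` is a placeholder for an existing `ModelCache` instance):

```python
stats = CacheStats()
cache.stats = stats  # attach before a run; the cache updates it on every get()
# ... perform loads and gets ...
print(f"hits={stats.hits} misses={stats.misses} "
      f"high_watermark={stats.high_watermark / 2**30:.2f}GB")
```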
@@ -0,0 +1,69 @@
+from typing import Any
+
+import torch
+
+
+class CachedModelOnlyFullLoad:
+    """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
+
+    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
+    MPS memory, etc.
+    """
+
+    def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
+        """Initialize a CachedModelOnlyFullLoad.
+
+        Args:
+            model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
+            compute_device (torch.device): The compute device to move the model to.
+            total_bytes (int): The total size (in bytes) of all the weights in the model.
+        """
+        # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
+        self._model = model
+        self._compute_device = compute_device
+        self._total_bytes = total_bytes
+        self._is_in_vram = False
+
+    @property
+    def model(self) -> torch.nn.Module:
+        return self._model
+
+    def total_bytes(self) -> int:
+        """Get the total size (in bytes) of all the weights in the model."""
+        return self._total_bytes
+
+    def is_in_vram(self) -> bool:
+        """Return true if the model is currently in VRAM."""
+        return self._is_in_vram
+
+    def full_load_to_vram(self) -> int:
+        """Load all weights into VRAM (if supported by the model).
+
+        Returns:
+            The number of bytes loaded into VRAM.
+        """
+        if self._is_in_vram:
+            # Already in VRAM.
+            return 0
+
+        if not hasattr(self._model, "to"):
+            # Model doesn't support moving to a device.
+            return 0
+
+        self._model.to(self._compute_device)
+        self._is_in_vram = True
+        return self._total_bytes
+
+    def full_unload_from_vram(self) -> int:
+        """Unload all weights from VRAM.
+
+        Returns:
+            The number of bytes unloaded from VRAM.
+        """
+        if not self._is_in_vram:
+            # Already in RAM.
+            return 0
+
+        self._model.to("cpu")
+        self._is_in_vram = False
+        return self._total_bytes
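A usage sketch for the full-load wrapper (illustrative; falls back to the CPU when CUDA is absent, in which case "VRAM" is simply CPU memory):

```python
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = torch.nn.Linear(32, 32)
total = sum(p.numel() * p.element_size() for p in model.parameters())

cached = CachedModelOnlyFullLoad(model=model, compute_device=device, total_bytes=total)
loaded = cached.full_load_to_vram()     # returns total_bytes; 0 if already loaded
assert cached.is_in_vram()
freed = cached.full_unload_from_vram()  # moves the model back to the CPU
```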
@@ -0,0 +1,84 @@
+import torch
+
+from invokeai.backend.util.calc_tensor_size import calc_tensor_size
+
+
+class CachedModelWithPartialLoad:
+    """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
+
+    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
+    MPS memory, etc.
+    """
+
+    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
+        self._model = model
+        self._compute_device = compute_device
+
+    # TODO(ryand): Add memoization for total_bytes and cur_vram_bytes?
+
+    @property
+    def model(self) -> torch.nn.Module:
+        return self._model
+
+    def total_bytes(self) -> int:
+        """Get the total size (in bytes) of all the weights in the model."""
+        return sum(calc_tensor_size(p) for p in self._model.parameters())
+
+    def cur_vram_bytes(self) -> int:
+        """Get the size (in bytes) of the weights that are currently in VRAM."""
+        return sum(calc_tensor_size(p) for p in self._model.parameters() if p.device.type == self._compute_device.type)
+
+    def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
+        """Load more weights into VRAM without exceeding vram_bytes_to_load.
+
+        Returns:
+            The number of bytes loaded into VRAM.
+        """
+        vram_bytes_loaded = 0
+
+        # TODO(ryand): Should we use self._model.apply(...) instead and move modules around instead of moving tensors?
+        # This way we don't have to use the private _apply() method.
+        def to_vram(t: torch.Tensor):
+            nonlocal vram_bytes_loaded
+
+            # Skip parameters that are already on the compute device.
+            if t.device.type == self._compute_device.type:
+                return t
+
+            # Check the size of the parameter.
+            param_size = calc_tensor_size(t)
+            if vram_bytes_loaded + param_size > vram_bytes_to_load:
+                # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
+                # worth continuing to search for a smaller parameter that would fit?
+                return t
+
+            vram_bytes_loaded += param_size
+            return t.to(self._compute_device)
+
+        self._model._apply(to_vram)
+
+        return vram_bytes_loaded
+
+    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
+        """Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
+
+        Returns:
+            The number of bytes unloaded from VRAM.
+        """
+        vram_bytes_freed = 0
+
+        def from_vram(t: torch.Tensor):
+            nonlocal vram_bytes_freed
+
+            if vram_bytes_freed >= vram_bytes_to_free:
+                return t
+
+            if t.device.type != self._compute_device.type:
+                return t
+
+            vram_bytes_freed += calc_tensor_size(t)
+            return t.to("cpu")
+
+        self._model._apply(from_vram)
+
+        return vram_bytes_freed
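And a sketch of the partial-load wrapper under a byte budget (illustrative; on a machine without CUDA everything already counts as being on the compute device type and nothing moves):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
cached = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))

budget = cached.total_bytes() // 2
loaded = cached.partial_load_to_vram(budget)  # moves tensors until the budget is hit
print(cached.cur_vram_bytes(), "of", cached.total_bytes(), "bytes on the compute device")

cached.partial_unload_from_vram(loaded)       # free what was just loaded
```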
@@ -1,11 +1,9 @@
 # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
 # TODO: Add Stalker's proper name to copyright
 """ """

 import gc
 import math
 import time
-from contextlib import suppress
 from logging import Logger
 from typing import Dict, List, Optional
@@ -13,13 +11,8 @@ import torch

 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
-    CacheRecord,
-    CacheStats,
-    ModelCacheBase,
-    ModelLockerBase,
-)
-from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
+from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
@@ -31,7 +24,14 @@ GB = 2**30
 MB = 2**20


-class ModelCache(ModelCacheBase[AnyModel]):
+def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
+    if submodel_type:
+        return f"{model_key}:{submodel_type.value}"
+    else:
+        return model_key
+
+
+class ModelCache:
     """A cache for managing models in memory.

     The cache is based on two levels of model storage:
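Key composition in one line (assuming `SubModelType.VAE` carries the lowercase string value `"vae"`, consistent with the enum's other values):

```python
from invokeai.backend.model_manager import SubModelType

get_model_cache_key("abc123")                    # -> "abc123"
get_model_cache_key("abc123", SubModelType.VAE)  # -> "abc123:vae"
```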
@@ -70,7 +70,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         max_vram_cache_size: float,
         execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
-        precision: torch.dtype = torch.float16,
         lazy_offloading: bool = True,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
@@ -82,7 +81,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         :param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
-        :param precision: Precision for loaded models [torch.float16]
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
@@ -100,29 +98,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._log_memory_usage = log_memory_usage
         self._stats: Optional[CacheStats] = None

-        self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
+        self._cached_models: Dict[str, CacheRecord] = {}
         self._cache_stack: List[str] = []

-    @property
-    def logger(self) -> Logger:
-        """Return the logger used by the cache."""
-        return self._logger
-
-    @property
-    def lazy_offloading(self) -> bool:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
-        return self._lazy_offloading
-
-    @property
-    def storage_device(self) -> torch.device:
-        """Return the storage device (e.g. "CPU" for RAM)."""
-        return self._storage_device
-
-    @property
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
-        return self._execution_device
-
     @property
     def max_cache_size(self) -> float:
         """Return the cap on cache size."""
@@ -153,49 +131,26 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """Set the CacheStats object for collecting cache statistics."""
         self._stats = stats

-    def cache_size(self) -> int:
-        """Get the total size of the models currently cached."""
-        total = 0
-        for cache_record in self._cached_models.values():
-            total += cache_record.size
-        return total
-
-    def put(
-        self,
-        key: str,
-        model: AnyModel,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> None:
-        """Store model under key and optional submodel_type."""
-        key = self._make_cache_key(key, submodel_type)
+    def put(self, key: str, model: AnyModel) -> None:
         if key in self._cached_models:
             return
-        size = calc_model_size_by_data(self.logger, model)
+        size = calc_model_size_by_data(self._logger, model)
         self.make_room(size)

-        running_on_cpu = self.execution_device == torch.device("cpu")
+        running_on_cpu = self._execution_device == torch.device("cpu")
         state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+        cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)

-    def get(
-        self,
-        key: str,
-        submodel_type: Optional[SubModelType] = None,
-        stats_name: Optional[str] = None,
-    ) -> ModelLockerBase:
-        """
-        Retrieve model using key and optional submodel_type.
+    def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
+        """Retrieve a model from the cache.

-        :param key: Opaque model key
-        :param submodel_type: Type of the submodel to fetch
-        :param stats_name: A human-readable id for the model for the purposes of
-           stats reporting.
+        :param key: Model key
+        :param stats_name: A human-readable id for the model for the purposes of stats reporting.

-        This may raise an IndexError if the model is not in the cache.
+        Raises IndexError if the model is not in the cache.
         """
-        key = self._make_cache_key(key, submodel_type)
         if key in self._cached_models:
             if self.stats:
                 self.stats.hits += 1
@@ -210,20 +165,52 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if self.stats:
             stats_name = stats_name or key
             self.stats.cache_size = int(self._max_cache_size * GB)
-            self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
+            self.stats.high_watermark = max(self.stats.high_watermark, self._get_cache_size())
             self.stats.in_cache = len(self._cached_models)
             self.stats.loaded_model_sizes[stats_name] = max(
                 self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
             )

         # this moves the entry to the top (right end) of the stack
-        with suppress(Exception):
-            self._cache_stack.remove(key)
+        self._cache_stack = [k for k in self._cache_stack if k != key]
         self._cache_stack.append(key)
-        return ModelLocker(
-            cache=self,
-            cache_entry=cache_entry,
-        )
+
+        return cache_entry
+
+    def lock(self, key: str) -> None:
+        """Lock a model for use and move it into VRAM."""
+        cache_entry = self._cached_models[key]
+        cache_entry.lock()
+
+        try:
+            if self._lazy_offloading:
+                self._offload_unlocked_models(cache_entry.size)
+            self._move_model_to_device(cache_entry, self._execution_device)
+            cache_entry.loaded = True
+            self._logger.debug(f"Locking {cache_entry.key} in {self._execution_device}")
+            self._print_cuda_stats()
+        except torch.cuda.OutOfMemoryError:
+            self._logger.warning("Insufficient GPU memory to load model. Aborting")
+            cache_entry.unlock()
+            raise
+        except Exception:
+            cache_entry.unlock()
+            raise
+
+    def unlock(self, key: str) -> None:
+        """Unlock a model."""
+        cache_entry = self._cached_models[key]
+        cache_entry.unlock()
+        if not self._lazy_offloading:
+            self._offload_unlocked_models(0)
+        self._print_cuda_stats()
+
+    def _get_cache_size(self) -> int:
+        """Get the total size of the models currently cached."""
+        total = 0
+        for cache_record in self._cached_models.values():
+            total += cache_record.size
+        return total
+
     def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
         if self._log_memory_usage:
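Putting the pieces together, a hedged sketch of the refactored flow (constructor arguments, the key, and `my_vae` are illustrative, not taken from this diff):

```python
cache = ModelCache(max_cache_size=6.0, max_vram_cache_size=2.0)
cache.put("abc123:vae", model=my_vae)                    # my_vae: a loaded torch module
record = cache.get("abc123:vae", stats_name="sdxl:vae")  # raises IndexError on a miss
cache.lock(record.key)   # moves the weights onto the execution device
try:
    ...                  # run inference with record.model
finally:
    cache.unlock(record.key)  # allows offload/eviction again
```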
@@ -236,30 +223,30 @@ class ModelCache(ModelCacheBase[AnyModel]):
         else:
             return model_key

-    def offload_unlocked_models(self, size_required: int) -> None:
+    def _offload_unlocked_models(self, size_required: int) -> None:
         """Offload models from the execution_device to make room for size_required.

         :param size_required: The amount of space to clear in the execution_device cache, in bytes.
         """
         reserved = self._max_vram_cache_size * GB
         vram_in_use = torch.cuda.memory_allocated() + size_required
-        self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
+        self._logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
             if not cache_entry.loaded:
                 continue
-            if not cache_entry.locked:
-                self.move_model_to_device(cache_entry, self.storage_device)
+            if not cache_entry.is_locked:
+                self._move_model_to_device(cache_entry, self._storage_device)
                 cache_entry.loaded = False
                 vram_in_use = torch.cuda.memory_allocated() + size_required
-                self.logger.debug(
+                self._logger.debug(
                     f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
                 )

         TorchDevice.empty_cache()

-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
+    def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
         """Move model into the indicated device.

         :param cache_entry: The CacheRecord for the model
@@ -267,7 +254,7 @@ class ModelCache(ModelCacheBase[AnyModel]):

         May raise a torch.cuda.OutOfMemoryError
         """
-        self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
+        self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
         source_device = cache_entry.device

         # Note: We compare device types only so that 'cuda' == 'cuda:0'.
@@ -294,7 +281,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         try:
             if cache_entry.state_dict is not None:
                 assert hasattr(cache_entry.model, "load_state_dict")
-                if target_device == self.storage_device:
+                if target_device == self._storage_device:
                     cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
                 else:
                     new_dict: Dict[str, torch.Tensor] = {}
@@ -309,7 +296,7 @@ class ModelCache(ModelCacheBase[AnyModel]):

         snapshot_after = self._capture_memory_snapshot()
         end_model_to_time = time.time()
-        self.logger.debug(
+        self._logger.debug(
             f"Moved model '{cache_entry.key}' from {source_device} to"
             f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
             f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
@@ -331,7 +318,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             rel_tol=0.1,
             abs_tol=10 * MB,
         ):
-            self.logger.debug(
+            self._logger.debug(
                 f"Moving model '{cache_entry.key}' from {source_device} to"
                 f" {target_device} caused an unexpected change in VRAM usage. The model's"
                 " estimated size may be incorrect. Estimated model size:"
@@ -339,24 +326,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
             f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
         )

-    def print_cuda_stats(self) -> None:
+    def _print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
-        ram = "%4.2fG" % (self.cache_size() / GB)
+        ram = "%4.2fG" % (self._get_cache_size() / GB)

         in_ram_models = 0
         in_vram_models = 0
         locked_in_vram_models = 0
         for cache_record in self._cached_models.values():
             if hasattr(cache_record.model, "device"):
-                if cache_record.model.device == self.storage_device:
+                if cache_record.model.device == self._storage_device:
                     in_ram_models += 1
                 else:
                     in_vram_models += 1
-                if cache_record.locked:
+                if cache_record.is_locked:
                     locked_in_vram_models += 1

-        self.logger.debug(
+        self._logger.debug(
             f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
             f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
         )
@@ -369,16 +356,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
         garbage-collected.
         """
         bytes_needed = size
-        maximum_size = self.max_cache_size * GB  # stored in GB, convert to bytes
-        current_size = self.cache_size()
+        maximum_size = self._max_cache_size * GB  # stored in GB, convert to bytes
+        current_size = self._get_cache_size()

         if current_size + bytes_needed > maximum_size:
-            self.logger.debug(
+            self._logger.debug(
                 f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
                 f" {(bytes_needed/GB):.2f} GB"
             )

-        self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
+        self._logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")

         pos = 0
         models_cleared = 0
@@ -386,12 +373,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
             model_key = self._cache_stack[pos]
             cache_entry = self._cached_models[model_key]
             device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
-            self.logger.debug(
+            self._logger.debug(
                 f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
             )

-            if not cache_entry.locked:
-                self.logger.debug(
+            if not cache_entry.is_locked:
+                self._logger.debug(
                     f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
                 )
                 current_size -= cache_entry.size
@@ -419,8 +406,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
             gc.collect()

         TorchDevice.empty_cache()
-        self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
+        self._logger.debug(f"After making room: cached_models={len(self._cached_models)}")

-    def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
+    def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
         self._cache_stack.remove(cache_entry.key)
         del self._cached_models[cache_entry.key]
@@ -1,221 +0,0 @@
-# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
-# TODO: Add Stalker's proper name to copyright
-"""
-Manage a RAM cache of diffusion/transformer models for fast switching.
-They are moved between GPU VRAM and CPU RAM as necessary. If the cache
-grows larger than a preset maximum, then the least recently used
-model will be cleared and (re)loaded from disk when next needed.
-"""
-
-from abc import ABC, abstractmethod
-from dataclasses import dataclass, field
-from logging import Logger
-from typing import Dict, Generic, Optional, TypeVar
-
-import torch
-
-from invokeai.backend.model_manager.config import AnyModel, SubModelType
-
-
-class ModelLockerBase(ABC):
-    """Base class for the model locker used by the loader."""
-
-    @abstractmethod
-    def lock(self) -> AnyModel:
-        """Lock the contained model and move it into VRAM."""
-        pass
-
-    @abstractmethod
-    def unlock(self) -> None:
-        """Unlock the contained model, and remove it from VRAM."""
-        pass
-
-    @abstractmethod
-    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
-        """Return the state dict (if any) for the cached model."""
-        pass
-
-    @property
-    @abstractmethod
-    def model(self) -> AnyModel:
-        """Return the model."""
-        pass
-
-
-T = TypeVar("T")
-
-
-@dataclass
-class CacheRecord(Generic[T]):
-    """
-    Elements of the cache:
-
-    key: Unique key for each model, same as used in the models database.
-    model: Model in memory.
-    state_dict: A read-only copy of the model's state dict in RAM. It will be
-        used as a template for creating a copy in the VRAM.
-    size: Size of the model
-    loaded: True if the model's state dict is currently in VRAM
-
-    Before a model is executed, the state_dict template is copied into VRAM,
-    and then injected into the model. When the model is finished, the VRAM
-    copy of the state dict is deleted, and the RAM version is reinjected
-    into the model.
-
-    The state_dict should be treated as a read-only attribute. Do not attempt
-    to patch or otherwise modify it. Instead, patch the copy of the state_dict
-    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
-    context manager call `model_on_device()`.
-    """
-
-    key: str
-    model: T
-    device: torch.device
-    state_dict: Optional[Dict[str, torch.Tensor]]
-    size: int
-    loaded: bool = False
-    _locks: int = 0
-
-    def lock(self) -> None:
-        """Lock this record."""
-        self._locks += 1
-
-    def unlock(self) -> None:
-        """Unlock this record."""
-        self._locks -= 1
-        assert self._locks >= 0
-
-    @property
-    def locked(self) -> bool:
-        """Return true if record is locked."""
-        return self._locks > 0
-
-
-@dataclass
-class CacheStats(object):
-    """Collect statistics on cache performance."""
-
-    hits: int = 0  # cache hits
-    misses: int = 0  # cache misses
-    high_watermark: int = 0  # amount of cache used
-    in_cache: int = 0  # number of models in cache
-    cleared: int = 0  # number of models cleared to make space
-    cache_size: int = 0  # total size of cache
-    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
-
-
-class ModelCacheBase(ABC, Generic[T]):
-    """Virtual base class for RAM model cache."""
-
-    @property
-    @abstractmethod
-    def storage_device(self) -> torch.device:
-        """Return the storage device (e.g. "CPU" for RAM)."""
-        pass
-
-    @property
-    @abstractmethod
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
-        pass
-
-    @property
-    @abstractmethod
-    def lazy_offloading(self) -> bool:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
-        pass
-
-    @property
-    @abstractmethod
-    def max_cache_size(self) -> float:
-        """Return the maximum size the RAM cache can grow to."""
-        pass
-
-    @max_cache_size.setter
-    @abstractmethod
-    def max_cache_size(self, value: float) -> None:
-        """Set the cap on vram cache size."""
-
-    @property
-    @abstractmethod
-    def max_vram_cache_size(self) -> float:
-        """Return the maximum size the VRAM cache can grow to."""
-        pass
-
-    @max_vram_cache_size.setter
-    @abstractmethod
-    def max_vram_cache_size(self, value: float) -> float:
-        """Set the maximum size the VRAM cache can grow to."""
-        pass
-
-    @abstractmethod
-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Offload from VRAM any models not actively in use."""
-        pass
-
-    @abstractmethod
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device."""
-        pass
-
-    @property
-    @abstractmethod
-    def stats(self) -> Optional[CacheStats]:
-        """Return collected CacheStats object."""
-        pass
-
-    @stats.setter
-    @abstractmethod
-    def stats(self, stats: CacheStats) -> None:
-        """Set the CacheStats object for collectin cache statistics."""
-        pass
-
-    @property
-    @abstractmethod
-    def logger(self) -> Logger:
-        """Return the logger used by the cache."""
-        pass
-
-    @abstractmethod
-    def make_room(self, size: int) -> None:
-        """Make enough room in the cache to accommodate a new model of indicated size."""
-        pass
-
-    @abstractmethod
-    def put(
-        self,
-        key: str,
-        model: T,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> None:
-        """Store model under key and optional submodel_type."""
-        pass
-
-    @abstractmethod
-    def get(
-        self,
-        key: str,
-        submodel_type: Optional[SubModelType] = None,
-        stats_name: Optional[str] = None,
-    ) -> ModelLockerBase:
-        """
-        Retrieve model using key and optional submodel_type.
-
-        :param key: Opaque model key
-        :param submodel_type: Type of the submodel to fetch
-        :param stats_name: A human-readable id for the model for the purposes of
-           stats reporting.
-
-        This may raise an IndexError if the model is not in the cache.
-        """
-        pass
-
-    @abstractmethod
-    def cache_size(self) -> int:
-        """Get the total size of the models currently cached."""
-        pass
-
-    @abstractmethod
-    def print_cuda_stats(self) -> None:
-        """Log debugging information on CUDA usage."""
-        pass
@@ -1,64 +0,0 @@
-"""
-Base class and implementation of a class that moves models in and out of VRAM.
-"""
-
-from typing import Dict, Optional
-
-import torch
-
-from invokeai.backend.model_manager import AnyModel
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
-    CacheRecord,
-    ModelCacheBase,
-    ModelLockerBase,
-)
-
-
-class ModelLocker(ModelLockerBase):
-    """Internal class that mediates movement in and out of GPU."""
-
-    def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
-        """
-        Initialize the model locker.
-
-        :param cache: The ModelCache object
-        :param cache_entry: The entry in the model cache
-        """
-        self._cache = cache
-        self._cache_entry = cache_entry
-
-    @property
-    def model(self) -> AnyModel:
-        """Return the model without moving it around."""
-        return self._cache_entry.model
-
-    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
-        """Return the state dict (if any) for the cached model."""
-        return self._cache_entry.state_dict
-
-    def lock(self) -> AnyModel:
-        """Move the model into the execution device (GPU) and lock it."""
-        self._cache_entry.lock()
-        try:
-            if self._cache.lazy_offloading:
-                self._cache.offload_unlocked_models(self._cache_entry.size)
-            self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
-            self._cache_entry.loaded = True
-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
-            self._cache.print_cuda_stats()
-        except torch.cuda.OutOfMemoryError:
-            self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
-            self._cache_entry.unlock()
-            raise
-        except Exception:
-            self._cache_entry.unlock()
-            raise
-
-        return self.model
-
-    def unlock(self) -> None:
-        """Call upon exit from context."""
-        self._cache_entry.unlock()
-        if not self._cache.lazy_offloading:
-            self._cache.offload_unlocked_models(0)
-        self._cache.print_cuda_stats()
@@ -26,7 +26,7 @@ from invokeai.backend.model_manager import (
     SubModelType,
 )
 from invokeai.backend.model_manager.load.load_default import ModelLoader
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry


@@ -40,7 +40,7 @@ class LoRALoader(ModelLoader):
         self,
         app_config: InvokeAIAppConfig,
         logger: Logger,
-        ram_cache: ModelCacheBase[AnyModel],
+        ram_cache: ModelCache,
     ):
         """Initialize the loader."""
         super().__init__(app_config, logger, ram_cache)

@@ -25,6 +25,7 @@ from invokeai.backend.model_manager.config import (
     DiffusersConfigBase,
     MainCheckpointConfig,
 )
+from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
 from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -132,5 +133,5 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
             if subtype == submodel_type:
                 continue
             if submodel := getattr(pipeline, subtype.value, None):
-                self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
+                self._ram_cache.put(get_model_cache_key(config.key, subtype), model=submodel)
         return getattr(pipeline, submodel_type.value)

@@ -96,7 +96,9 @@
|
||||
"new": "Neu",
|
||||
"ok": "OK",
|
||||
"close": "Schließen",
|
||||
"clipboard": "Zwischenablage"
|
||||
"clipboard": "Zwischenablage",
|
||||
"generating": "Generieren",
|
||||
"loadingModel": "Lade Modell"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Bildgröße",
|
||||
@@ -591,7 +593,15 @@
|
||||
"loraTriggerPhrases": "LoRA-Auslösephrasen",
|
||||
"installingBundle": "Bündel wird installiert",
|
||||
"triggerPhrases": "Auslösephrasen",
|
||||
"mainModelTriggerPhrases": "Hauptmodell-Auslösephrasen"
|
||||
"mainModelTriggerPhrases": "Hauptmodell-Auslösephrasen",
|
||||
"noDefaultSettings": "Für dieses Modell sind keine Standardeinstellungen konfiguriert. Besuchen Sie den Modell-Manager, um Standardeinstellungen hinzuzufügen.",
|
||||
"defaultSettingsOutOfSync": "Einige Einstellungen stimmen nicht mit den Standardeinstellungen des Modells überein:",
|
||||
"clipLEmbed": "CLIP-L einbetten",
|
||||
"clipGEmbed": "CLIP-G einbetten",
|
||||
"hfTokenLabel": "HuggingFace-Token (für einige Modelle erforderlich)",
|
||||
"hfTokenHelperText": "Für die Nutzung einiger Modelle ist ein HF-Token erforderlich. Klicken Sie hier, um Ihr Token zu erstellen oder zu erhalten.",
|
||||
"hfForbidden": "Sie haben keinen Zugriff auf dieses HF-Modell",
|
||||
"hfTokenInvalid": "Ungültiges oder fehlendes HF-Token"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Bilder",
|
||||
@@ -841,7 +851,8 @@
|
||||
"upscaling": "Hochskalierung",
|
||||
"canvas": "Leinwand",
|
||||
"prompts_one": "Prompt",
|
||||
"prompts_other": "Prompts"
|
||||
"prompts_other": "Prompts",
|
||||
"batchSize": "Stapelgröße"
|
||||
},
|
||||
"metadata": {
|
||||
"negativePrompt": "Negativ Beschreibung",
|
||||
@@ -1081,6 +1092,21 @@
|
||||
},
|
||||
"patchmatchDownScaleSize": {
|
||||
"heading": "Herunterskalieren"
|
||||
},
|
||||
"paramHeight": {
|
||||
"heading": "Höhe",
|
||||
"paragraphs": [
|
||||
"Höhe des generierten Bildes. Muss ein Vielfaches von 8 sein."
|
||||
]
|
||||
},
|
||||
"paramUpscaleMethod": {
|
||||
"heading": "Vergrößerungsmethode",
|
||||
"paragraphs": [
|
||||
"Methode zum Hochskalieren des Bildes für High Resolution Fix."
|
||||
]
|
||||
},
|
||||
"paramHrf": {
|
||||
"heading": "High Resolution Fix aktivieren"
|
||||
}
|
||||
},
|
||||
"invocationCache": {
|
||||
|
||||
@@ -176,7 +176,8 @@
|
||||
"reset": "Reset",
|
||||
"none": "None",
|
||||
"new": "New",
|
||||
"generating": "Generating"
|
||||
"generating": "Generating",
|
||||
"warnings": "Warnings"
|
||||
},
|
||||
"hrf": {
|
||||
"hrf": "High Resolution Fix",
|
||||
@@ -1038,20 +1039,7 @@
|
||||
"canvasIsSelectingObject": "Canvas is busy (selecting object)",
|
||||
"noPrompts": "No prompts generated",
|
||||
"noNodesInGraph": "No nodes in graph",
|
||||
"systemDisconnected": "System disconnected",
|
||||
"layer": {
|
||||
"controlAdapterNoModelSelected": "no Control Adapter model selected",
|
||||
"controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
|
||||
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, bbox width is {{width}}",
|
||||
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, bbox height is {{height}}",
|
||||
"t2iAdapterIncompatibleScaledBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, scaled bbox width is {{width}}",
|
||||
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, scaled bbox height is {{height}}",
|
||||
"ipAdapterNoModelSelected": "no IP adapter selected",
|
||||
"ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
|
||||
"ipAdapterNoImageSelected": "no IP Adapter image selected",
|
||||
"rgNoPromptsOrIPAdapters": "no text prompts or IP Adapters",
|
||||
"rgNoRegion": "no region selected"
|
||||
}
|
||||
"systemDisconnected": "System disconnected"
|
||||
},
|
||||
"maskBlur": "Mask Blur",
|
||||
"negativePromptPlaceholder": "Negative Prompt",
|
||||
@@ -1713,6 +1701,8 @@
|
||||
"controlLayer": "Control Layer",
|
||||
"inpaintMask": "Inpaint Mask",
|
||||
"regionalGuidance": "Regional Guidance",
|
||||
"referenceImageRegional": "Reference Image (Regional)",
|
||||
"referenceImageGlobal": "Reference Image (Global)",
|
||||
"asRasterLayer": "As $t(controlLayers.rasterLayer)",
|
||||
"asRasterLayerResize": "As $t(controlLayers.rasterLayer) (Resize)",
|
||||
"asControlLayer": "As $t(controlLayers.controlLayer)",
|
||||
@@ -1798,6 +1788,21 @@
|
||||
"replaceCurrent": "Replace Current",
|
||||
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or draw on the canvas to get started.",
|
||||
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer to get started.",
|
||||
"warnings": {
|
||||
"problemsFound": "Problems found",
|
||||
"unsupportedModel": "layer not supported for selected base model",
|
||||
"controlAdapterNoModelSelected": "no Control Layer model selected",
|
||||
"controlAdapterIncompatibleBaseModel": "incompatible Control Layer base model",
|
||||
"controlAdapterNoControl": "no control selected/drawn",
|
||||
"ipAdapterNoModelSelected": "no Reference Image model selected",
|
||||
"ipAdapterIncompatibleBaseModel": "incompatible Reference Image base model",
|
||||
"ipAdapterNoImageSelected": "no Reference Image image selected",
|
||||
"rgNoPromptsOrIPAdapters": "no text prompts or Reference Images",
|
||||
"rgNegativePromptNotSupported": "Negative Prompt not supported for selected base model",
|
||||
"rgReferenceImagesNotSupported": "regional Reference Images not supported for selected base model",
|
||||
"rgAutoNegativeNotSupported": "Auto-Negative not supported for selected base model",
|
||||
"rgNoRegion": "no region drawn"
|
||||
},
|
||||
"controlMode": {
|
||||
"controlMode": "Control Mode",
|
||||
"balanced": "Balanced (recommended)",
|
||||
|
||||
@@ -327,7 +327,6 @@
       "t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, la hauteur de la bounding box est {{height}}",
       "t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, la largeur de la bounding box est {{width}}",
       "ipAdapterIncompatibleBaseModel": "modèle de base d'IP adapter incompatible",
-      "rgNoRegion": "aucune zone sélectionnée",
       "controlAdapterNoModelSelected": "aucun modèle de Control Adapter sélectionné"
     },
     "noPrompts": "Aucun prompts généré",

@@ -96,7 +96,8 @@
     "clipboard": "Appunti",
     "ok": "Ok",
     "generating": "Generazione",
-    "loadingModel": "Caricamento del modello"
+    "loadingModel": "Caricamento del modello",
+    "warnings": "Avvisi"
   },
   "gallery": {
     "galleryImageSize": "Dimensione dell'immagine",
@@ -671,11 +672,15 @@
       "ipAdapterIncompatibleBaseModel": "Il modello base dell'adattatore IP non è compatibile",
       "ipAdapterNoImageSelected": "Nessuna immagine dell'adattatore IP selezionata",
       "rgNoPromptsOrIPAdapters": "Nessun prompt o adattatore IP",
-      "rgNoRegion": "Nessuna regione selezionata",
       "t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, larghezza riquadro è {{width}}",
       "t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, altezza riquadro è {{height}}",
       "t2iAdapterIncompatibleScaledBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, larghezza del riquadro scalato {{width}}",
-      "t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, altezza del riquadro scalato {{height}}"
+      "t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, altezza del riquadro scalato {{height}}",
+      "rgNegativePromptNotSupported": "prompt negativo non supportato per il modello base selezionato",
+      "rgAutoNegativeNotSupported": "auto-negativo non supportato per il modello base selezionato",
+      "emptyLayer": "livello vuoto",
+      "unsupportedModel": "livello non supportato per il modello base selezionato",
+      "rgReferenceImagesNotSupported": "immagini di riferimento regionali non supportate per il modello base selezionato"
     },
     "fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), altezza riquadro è {{height}}",
     "fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), larghezza riquadro è {{width}}",
@@ -687,7 +692,11 @@
     "canvasIsTransforming": "La tela sta trasformando",
     "canvasIsRasterizing": "La tela sta rasterizzando",
     "canvasIsCompositing": "La tela è in fase di composizione",
-    "canvasIsFiltering": "La tela sta filtrando"
+    "canvasIsFiltering": "La tela sta filtrando",
+    "collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi elementi, massimo {{maxItems}}",
+    "canvasIsSelectingObject": "La tela è occupata (selezione dell'oggetto)",
+    "collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi pochi elementi, minimo {{minItems}}",
+    "collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} raccolta vuota"
   },
   "useCpuNoise": "Usa la CPU per generare rumore",
   "iterations": "Iterazioni",
@@ -972,7 +981,9 @@
     "saveToGallery": "Salva nella Galleria",
     "noMatchingWorkflows": "Nessun flusso di lavoro corrispondente",
     "noWorkflows": "Nessun flusso di lavoro",
-    "workflowHelpText": "Hai bisogno di aiuto? Consulta la nostra guida <LinkComponent>Introduzione ai flussi di lavoro</LinkComponent>."
+    "workflowHelpText": "Hai bisogno di aiuto? Consulta la nostra guida <LinkComponent>Introduzione ai flussi di lavoro</LinkComponent>.",
+    "specialDesc": "Questa invocazione comporta una gestione speciale nell'applicazione. Ad esempio, i nodi Lotto vengono utilizzati per mettere in coda più grafici da un singolo flusso di lavoro.",
+    "internalDesc": "Questa invocazione è utilizzata internamente da Invoke. Potrebbe subire modifiche significative durante gli aggiornamenti dell'app e potrebbe essere rimossa in qualsiasi momento."
   },
   "boards": {
     "autoAddBoard": "Aggiungi automaticamente bacheca",
@@ -1093,7 +1104,8 @@
     "workflows": "Flussi di lavoro",
     "generation": "Generazione",
     "other": "Altro",
-    "gallery": "Galleria"
+    "gallery": "Galleria",
+    "batchSize": "Dimensione del lotto"
   },
   "models": {
     "noMatchingModels": "Nessun modello corrispondente",
@@ -1196,7 +1208,8 @@
     "heading": "Percentuale passi Inizio / Fine",
     "paragraphs": [
       "La parte del processo di rimozione del rumore in cui verrà applicato l'adattatore di controllo.",
-      "In genere, gli adattatori di controllo applicati all'inizio del processo guidano la composizione, mentre quelli applicati alla fine guidano i dettagli."
+      "In genere, gli adattatori di controllo applicati all'inizio del processo guidano la composizione, mentre quelli applicati alla fine guidano i dettagli.",
+      "• Passo finale (%): specifica quando interrompere l'applicazione della guida di questo livello e ripristinare la guida generale dal modello e altre impostazioni."
     ]
   },
   "noiseUseCPU": {
@@ -1300,7 +1313,9 @@
   "controlNetWeight": {
     "heading": "Peso",
     "paragraphs": [
-      "Peso dell'adattatore di controllo. Un peso maggiore porterà a impatti maggiori sull'immagine finale."
+      "Regola la forza con cui il livello influenza il processo di generazione",
+      "• Peso maggiore (0.75-2): crea un impatto più significativo sul risultato finale.",
+      "• Peso inferiore (0-0.75): crea un impatto minore sul risultato finale."
     ]
   },
   "paramCFGScale": {
@@ -1801,7 +1816,10 @@
     "full": "Stile e Composizione",
     "style": "Solo Stile",
     "composition": "Solo Composizione",
-    "ipAdapterMethod": "Metodo Adattatore IP"
+    "ipAdapterMethod": "Metodo Adattatore IP",
+    "fullDesc": "Applica lo stile visivo (colori, texture) e la composizione (disposizione, struttura).",
+    "styleDesc": "Applica lo stile visivo (colori, texture) senza considerare la disposizione.",
+    "compositionDesc": "Replica disposizione e struttura ignorando lo stile di riferimento."
   },
   "showingType": "Mostra {{type}}",
   "dynamicGrid": "Griglia dinamica",
@@ -2044,7 +2062,16 @@
     "replaceCurrent": "Sostituisci corrente",
     "mergeDown": "Unire in basso",
     "mergingLayers": "Unione dei livelli",
-    "controlLayerEmptyState": "<UploadButton>Carica un'immagine</UploadButton>, trascina un'immagine dalla <GalleryButton>galleria</GalleryButton> su questo livello oppure disegna sulla tela per iniziare."
+    "controlLayerEmptyState": "<UploadButton>Carica un'immagine</UploadButton>, trascina un'immagine dalla <GalleryButton>galleria</GalleryButton> su questo livello oppure disegna sulla tela per iniziare.",
+    "useImage": "Usa immagine",
+    "resetGenerationSettings": "Ripristina impostazioni di generazione",
+    "referenceImageEmptyState": "Per iniziare, <UploadButton>carica un'immagine</UploadButton> oppure trascina un'immagine dalla <GalleryButton>galleria</GalleryButton> su questo livello.",
+    "asRasterLayer": "Come $t(controlLayers.rasterLayer)",
+    "asRasterLayerResize": "Come $t(controlLayers.rasterLayer) (Ridimensiona)",
+    "asControlLayer": "Come $t(controlLayers.controlLayer)",
+    "asControlLayerResize": "Come $t(controlLayers.controlLayer) (Ridimensiona)",
+    "newSession": "Nuova sessione",
+    "resetCanvasLayers": "Ripristina livelli Tela"
   },
   "ui": {
     "tabs": {
@@ -2144,7 +2171,7 @@
     "watchRecentReleaseVideos": "Guarda i video su questa versione",
     "watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
     "items": [
       "<StrongComponent>SD 3.5</StrongComponent>: supporto per SD 3.5 Medium e Large.",
       "<StrongComponent>Flussi di lavoro</StrongComponent>: esegui un flusso di lavoro per una raccolta di immagini utilizzando il nuovo nodo <StrongComponent>Lotto di immagini</StrongComponent>.",
       "<StrongComponent>Tela</StrongComponent>: elaborazione semplificata del livello di controllo e impostazioni di controllo predefinite migliorate."
     ]
   },
@@ -2172,5 +2199,67 @@
       "logNamespaces": "Elementi del registro"
     },
     "enableLogging": "Abilita la registrazione"
   },
+  "supportVideos": {
+    "gettingStarted": "Iniziare",
+    "supportVideos": "Video di supporto",
+    "videos": {
+      "usingControlLayersAndReferenceGuides": {
+        "title": "Utilizzo di livelli di controllo e guide di riferimento",
+        "description": "Scopri come guidare la creazione delle tue immagini con livelli di controllo e immagini di riferimento."
+      },
+      "creatingYourFirstImage": {
+        "description": "Introduzione alla creazione di un'immagine da zero utilizzando gli strumenti di Invoke.",
+        "title": "Creazione della tua prima immagine"
+      },
+      "understandingImageToImageAndDenoising": {
+        "description": "Panoramica delle trasformazioni immagine-a-immagine e della riduzione del rumore in Invoke.",
+        "title": "Comprendere immagine-a-immagine e riduzione del rumore"
+      },
+      "howDoIDoImageToImageTransformation": {
+        "description": "Tutorial su come eseguire trasformazioni da immagine a immagine in Invoke.",
+        "title": "Come si esegue la trasformazione da immagine-a-immagine?"
+      },
+      "howDoIUseInpaintMasks": {
+        "title": "Come si usano le maschere Inpaint?",
+        "description": "Come applicare maschere inpaint per la correzione e la variazione delle immagini."
+      },
+      "howDoIOutpaint": {
+        "description": "Guida all'outpainting oltre i confini dell'immagine originale.",
+        "title": "Come posso eseguire l'outpainting?"
+      },
+      "exploringAIModelsAndConceptAdapters": {
+        "description": "Approfondisci i modelli di intelligenza artificiale e scopri come utilizzare gli adattatori concettuali per il controllo creativo.",
+        "title": "Esplorazione dei modelli di IA e degli adattatori concettuali"
+      },
+      "upscaling": {
+        "title": "Ampliamento",
+        "description": "Come ampliare le immagini con gli strumenti di Invoke per migliorarne la risoluzione."
+      },
+      "creatingAndComposingOnInvokesControlCanvas": {
+        "description": "Impara a comporre immagini utilizzando la tela di controllo di Invoke.",
+        "title": "Creare e comporre sulla tela di controllo di Invoke"
+      },
+      "howDoIGenerateAndSaveToTheGallery": {
+        "description": "Passaggi per generare e salvare le immagini nella galleria.",
+        "title": "Come posso generare e salvare nella Galleria?"
+      },
+      "howDoIEditOnTheCanvas": {
+        "title": "Come posso apportare modifiche sulla tela?",
+        "description": "Guida alla modifica delle immagini direttamente sulla tela."
|
||||
},
|
||||
"howDoIUseControlNetsAndControlLayers": {
|
||||
"title": "Come posso utilizzare le Reti di Controllo e i Livelli di Controllo?",
|
||||
"description": "Impara ad applicare livelli di controllo e reti di controllo alle tue immagini."
|
||||
},
|
||||
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
|
||||
"title": "Come si utilizzano gli adattatori IP globali e le immagini di riferimento?",
|
||||
"description": "Introduzione all'aggiunta di immagini di riferimento e adattatori IP globali."
|
||||
}
|
||||
},
|
||||
"controlCanvas": "Tela di Controllo",
|
||||
"watch": "Guarda",
|
||||
"studioSessionsDesc1": "Dai un'occhiata a <StudioSessionsPlaylistLink /> per approfondimenti su Invoke.",
|
||||
"studioSessionsDesc2": "Unisciti al nostro <DiscordLink /> per partecipare alle sessioni live e fare domande. Le sessioni vengono caricate sulla playlist la settimana successiva."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,7 +236,6 @@
"controlAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor controle-adapter",
"ipAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor IP-adapter",
"ipAdapterNoImageSelected": "geen afbeelding voor IP-adapter geselecteerd",
"rgNoRegion": "geen gebied geselecteerd",
"rgNoPromptsOrIPAdapters": "geen tekstprompts of IP-adapters",
"ipAdapterNoModelSelected": "geen IP-adapter geselecteerd"
}

@@ -10,7 +10,24 @@
"load": "Załaduj",
"statusDisconnected": "Odłączono od serwera",
"githubLabel": "GitHub",
"discordLabel": "Discord"
"discordLabel": "Discord",
"clipboard": "Schowek",
"aboutDesc": "Wykorzystujesz Invoke do pracy? Sprawdź:",
"ai": "SI",
"areYouSure": "Czy jesteś pewien?",
"copyError": "$t(gallery.copy) Błąd",
"apply": "Zastosuj",
"copy": "Kopiuj",
"or": "albo",
"add": "Dodaj",
"off": "Wyłączony",
"accept": "Zaakceptuj",
"cancel": "Anuluj",
"advanced": "Zawansowane",
"back": "Do tyłu",
"auto": "Automatyczny",
"beta": "Beta",
"close": "Wyjdź"
},
"gallery": {
"galleryImageSize": "Rozmiar obrazów",
@@ -65,6 +82,42 @@
"uploadImage": "Wgrywanie obrazu",
"previousImage": "Poprzedni obraz",
"nextImage": "Następny obraz",
"menu": "Menu"
"menu": "Menu",
"mode": "Tryb"
},
"boards": {
"cancel": "Anuluj",
"noBoards": "Brak tablic typu {{boardType}}",
"imagesWithCount_one": "{{count}} zdjęcie",
"imagesWithCount_few": "{{count}} zdjęcia",
"imagesWithCount_many": "{{count}} zdjęcia",
"private": "Prywatne tablice",
"updateBoardError": "Błąd aktualizacji tablicy",
"uncategorized": "Nieskategoryzowane",
"selectBoard": "Wybierz tablicę",
"downloadBoard": "Pobierz tablice",
"loading": "Ładowanie...",
"move": "Przenieś",
"noMatching": "Brak pasujących tablic"
},
"accordions": {
"compositing": {
"title": "Kompozycja",
"infillTab": "Inskrypcja",
"coherenceTab": "Przebieg Koherencji"
},
"generation": {
"title": "Generowanie"
},
"image": {
"title": "Zdjęcie"
},
"advanced": {
"options": "$t(accordions.advanced.title) Opcje",
"title": "Zaawansowane"
},
"control": {
"title": "Kontrola"
}
}
}

@@ -652,7 +652,6 @@
"ipAdapterNoModelSelected": "IP адаптер не выбран",
"controlAdapterNoModelSelected": "не выбрана модель адаптера контроля",
"controlAdapterIncompatibleBaseModel": "несовместимая базовая модель адаптера контроля",
"rgNoRegion": "регион не выбран",
"rgNoPromptsOrIPAdapters": "нет текстовых запросов или IP-адаптеров",
"ipAdapterIncompatibleBaseModel": "несовместимая базовая модель IP-адаптера",
"ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано",

@@ -217,7 +217,10 @@
"direction": "Phương Hướng",
"unknownError": "Lỗi Không Rõ",
"selected": "Đã chọn",
"tab": "Tab"
"tab": "Tab",
"loadingModel": "Đang Tải Model",
"generating": "Đang Tạo Sinh",
"warnings": "Cảnh Báo"
},
"prompt": {
"addPromptTrigger": "Thêm Prompt Trigger",
@@ -290,7 +293,8 @@
"cancelSucceeded": "Mục Đã Huỷ Bỏ",
"completedIn": "Hoàn tất trong",
"graphQueued": "Đồ Thị Đã Vào Hàng",
"batchQueuedDesc_other": "Thêm {{count}} phiên vào {{direction}} của hàng"
"batchQueuedDesc_other": "Thêm {{count}} phiên vào {{direction}} của hàng",
"batchSize": "Kích Thước Vùng Hàng Loạt"
},
"hotkeys": {
"canvas": {
@@ -733,7 +737,9 @@
"textualInversions": "Bộ Đảo Ngược Văn Bản",
"loraTriggerPhrases": "Từ Ngữ Kích Hoạt Cho LoRA",
"width": "Chiều Rộng",
"starterModelsInModelManager": "Model khởi đầu có thể tìm thấy ở Trình Quản Lý Model"
"starterModelsInModelManager": "Model khởi đầu có thể tìm thấy ở Trình Quản Lý Model",
"clipLEmbed": "CLIP-L Embed",
"clipGEmbed": "CLIP-G Embed"
},
"metadata": {
"guidance": "Hướng Dẫn",
@@ -905,7 +911,7 @@
"unknownNode": "Node Không Rõ",
"unknownNodeType": "Loại Node Không Rõ",
"unknownTemplate": "Mẫu Trình Bày Không Rõ",
"cannotConnectOutputToOutput": "Không thế kết nối đầu ra với đầu vào",
"cannotConnectOutputToOutput": "Không thế kết nối đầu ra với đầu ra",
"cannotConnectToSelf": "Không thể kết nối với chính nó",
"workflow": "Workflow",
"addNodeToolTip": "Thêm Node (Shift+A, Space)",
@@ -952,7 +958,9 @@
"executionStateInProgress": "Đang Xử Lý",
"showLegendNodes": "Hiển Thị Vùng Nhập",
"outputFieldTypeParseError": "Không thể phân tích loại dữ liệu đầu ra của {{node}}.{{field}} ({{message}})",
"modelAccessError": "Không thể tìm thấy model {{key}}, chuyển về mặc định"
"modelAccessError": "Không thể tìm thấy model {{key}}, chuyển về mặc định",
"internalDesc": "Trình kích hoạt này được dùng bên trong bởi Invoke. Nó có thể phá hỏng thay đổi trong khi cập nhật ứng dụng và có thể bị xoá bất cứ lúc nào.",
"specialDesc": "Trình kích hoạt này có một số xử lý đặc biệt trong ứng dụng. Ví dụ, Node Hàng Loạt được dùng để xếp vào nhiều đồ thị từ một workflow."
},
"popovers": {
"paramCFGRescaleMultiplier": {
@@ -1105,7 +1113,9 @@
},
"controlNetWeight": {
"paragraphs": [
"Trọng lượng của Control Adapter. Trọng lượng càng cao sẽ dẫn đến tác động càng lớn lên ảnh cuối cùng."
"Điều chỉnh mức độ layer ảnh hưởng đến quá trình xử lý tạo sinh.",
"• Trọng Lượng Lớn Hơn (.75-2): Gây ra ảnh hưởng lớn hơn lên kết quả cuối cùng.",
"• Trọng Lượng Nhỏ Hơn (0-.75): Gây ra ảnh hưởng nhỏ hơn lên kết quả cuối cùng."
],
"heading": "Trọng Lượng"
},
@@ -1149,7 +1159,7 @@
},
"ipAdapterMethod": {
"paragraphs": [
"Cách thức dùng để áp dụng IP Adapter hiện tại."
"Phương thức định nghĩa cách ảnh mẫu sẽ chỉ dẫn quá trình xử lý tạo sinh."
],
"heading": "Cách Thức"
},
@@ -1196,8 +1206,9 @@
},
"controlNetBeginEnd": {
"paragraphs": [
"Một phần trong quá trình xử lý khử nhiễu mà sẽ được Control Adapter áp dụng.",
"Nói chung, Control Adapter áp dụng vào lúc bắt đầu của quá trình hướng dẫn thành phần, và cũng áp dụng vào lúc kết thúc hướng dẫn chi tiết."
"Cài đặt này xác định phần xử lý khử nhiễu (trong khi tạo sinh) kết hợp với chỉ dẫn từ layer này.",
"• Bước Bắt Đầu (%): Chỉ định lúc bắt đầu áp dụng chỉ dẫn từ layer này trong quá trình tạo sinh.",
"• Bước Kết Thúc (%): Chỉ định lúc dừng áp dụng chỉ dẫn của layer này và trở về chỉ dẫn chung từ model và các thiết lập khác."
],
"heading": "Phần Trăm Tham Số Bước Khi Bắt Đầu/Kết Thúc"
},
@@ -1401,7 +1412,6 @@
"invoke": {
"layer": {
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, tỉ lệ chiều dài hộp giới hạn là {{height}}",
"rgNoRegion": "không có vùng được chọn",
"ipAdapterNoModelSelected": "không có IP Adapter được lựa chọn",
"ipAdapterNoImageSelected": "không có ảnh IP Adapter được lựa chọn",
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, chiều dài hộp giới hạn là {{height}}",
@@ -1410,15 +1420,20 @@
"rgNoPromptsOrIPAdapters": "không có lệnh chữ hoặc IP Adapter",
"controlAdapterIncompatibleBaseModel": "model cơ sở của Control Adapter không tương thích",
"ipAdapterIncompatibleBaseModel": "dạng model cơ sở của IP Adapter không tương thích",
"controlAdapterNoModelSelected": "không có model Control Adapter được chọn"
"controlAdapterNoModelSelected": "không có model Control Adapter được chọn",
"emptyLayer": "layer trống",
"rgAutoNegativeNotSupported": "trình tự động đảo chiều không được hỗ trợ cho model cơ sở đang dùng",
"rgNegativePromptNotSupported": "lệnh tiêu cực không được hỗ trợ cho model cơ sở đang dùng",
"unsupportedModel": "layer không được hỗ trợ cho model cơ sở đang dùng",
"rgReferenceImagesNotSupported": "ảnh mẫu khu vực không được hỗ trợ cho model cơ sở đang dùng"
},
"fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), chiều rộng hộp giới hạn là {{width}}",
"noModelSelected": "Không có model được lựa chọn",
"fluxModelIncompatibleScaledBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), tỉ lệ chiều dài hộp giới hạn là {{height}}",
"canvasIsFiltering": "Canvas đang được lọc",
"canvasIsRasterizing": "Canvas đang được raster hoá",
"canvasIsTransforming": "Canvas đang được biến đổi",
"canvasIsCompositing": "Canvas đang được kết hợp",
"canvasIsFiltering": "Canvas đang bận (đang lọc)",
"canvasIsRasterizing": "Canvas đang bận (đang raster hoá)",
"canvasIsTransforming": "Canvas đang bận (đang biến đổi)",
"canvasIsCompositing": "Canvas đang bận (đang kết hợp)",
"noPrompts": "Không có lệnh được tạo",
"noNodesInGraph": "Không có node trong đồ thị",
"addingImagesTo": "Thêm ảnh vào",
@@ -1430,8 +1445,12 @@
"missingNodeTemplate": "Thiếu mẫu trình bày node",
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), chiều dài hộp giới hạn là {{height}}",
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), tỉ lệ chiều rộng hộp giới hạn là {{width}}",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} thiếu đầu ra",
"missingFieldTemplate": "Thiếu vùng mẫu trình bày"
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: thiếu đầu vào",
"missingFieldTemplate": "Thiếu vùng mẫu trình bày",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} tài nguyên trống",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: quá ít mục, tối thiểu {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: quá nhiều mục, tối đa {{maxItems}}",
"canvasIsSelectingObject": "Canvas đang bận (đang chọn đồ vật)"
},
"cfgScale": "Thước Đo CFG",
"useSeed": "Dùng Tham Số Hạt Giống",
@@ -1542,7 +1561,8 @@
"resetWebUIDesc2": "Nếu ảnh không được xuất hiện trong thư viện hoặc điều gì đó không ổn đang diễn ra, hãy thử khởi động lại trước khi báo lỗi trên Github.",
"displayInProgress": "Hiển Thị Hình Ảnh Đang Xử Lý",
"intermediatesClearedFailed": "Có Vấn Đề Khi Dọn Sạch Sản Phẩm Trung Gian",
"enableInvisibleWatermark": "Bật Chế Độ Ẩn Watermark"
"enableInvisibleWatermark": "Bật Chế Độ Ẩn Watermark",
"showDetailedInvocationProgress": "Hiện Dữ Liệu Xử Lý"
},
"sdxl": {
"loading": "Đang Tải...",
@@ -1594,7 +1614,7 @@
"pullBboxIntoLayerError": "Có Vấn Đề Khi Chuyển Hộp Giới Hạn Thành Layer",
"pullBboxIntoReferenceImageOk": "Chuyển Hộp Giới Hạn Thành Ảnh Mẫu",
"clearCaches": "Xoá Bộ Nhớ Đệm",
"outputOnlyMaskedRegions": "Chỉ Xuất Đầu Ra Ở Vùng Phủ",
"outputOnlyMaskedRegions": "Chỉ Xuất Đầu Ra Ở Vùng Tạo Sinh",
"addLayer": "Thêm Layer",
"regional": "Khu Vực",
"regionIsEmpty": "Vùng được chọn trống",
@@ -1608,10 +1628,13 @@
"moveForward": "Chuyển Lên Đầu",
"fitBboxToLayers": "Xếp Vừa Hộp Giới Hạn Vào Layer",
"ipAdapterMethod": {
"full": "Đầy Đủ",
"full": "Phong Cách Và Thành Phần",
"style": "Chỉ Lấy Phong Cách",
"composition": "Chỉ Lấy Thành Phần",
"ipAdapterMethod": "Cách Thức IP Adapter"
"ipAdapterMethod": "Cách Thức",
"compositionDesc": "Áp dụng cách trình bày và bỏ qua phong cách mẫu.",
"fullDesc": "Áp dụng phong cách trực quan (màu, cấu tạo) & thành phần (cách trình bày).",
"styleDesc": "Áp dụng phong cách trực quan (màu, cấu tạo) và bỏ qua cách trình bày."
},
"deletePrompt": "Xoá Lệnh",
"rasterLayer": "Layer Dạng Raster",
@@ -1899,7 +1922,16 @@
"colorPicker": "Chọn Màu"
},
"mergingLayers": "Đang gộp layer",
"controlLayerEmptyState": "<UploadButton>Tải lên ảnh</UploadButton>, kéo thả ảnh từ <GalleryButton>thư viện</GalleryButton> vào layer này, hoặc vẽ trên canvas để bắt đầu."
"controlLayerEmptyState": "<UploadButton>Tải lên ảnh</UploadButton>, kéo thả ảnh từ <GalleryButton>thư viện</GalleryButton> vào layer này, hoặc vẽ trên canvas để bắt đầu.",
"referenceImageEmptyState": "<UploadButton>Tải lên ảnh</UploadButton> hoặc kéo thả ảnh từ <GalleryButton>thư viện</GalleryButton> vào layer này để bắt đầu.",
"useImage": "Dùng Hình Ảnh",
"resetCanvasLayers": "Khởi Động Lại Layer Canvas",
"asRasterLayer": "Như $t(controlLayers.rasterLayer)",
"asRasterLayerResize": "Như $t(controlLayers.rasterLayer) (Thay Đổi Kích Thước)",
"asControlLayer": "Như $t(controlLayers.controlLayer)",
"asControlLayerResize": "Như $t(controlLayers.controlLayer) (Thay Đổi Kích Thước)",
"newSession": "Phiên Làm Việc Mới",
"resetGenerationSettings": "Khởi Động Lại Cài Đặt Tạo Sinh"
},
"stylePresets": {
"negativePrompt": "Lệnh Tiêu Cực",
@@ -2124,8 +2156,8 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"<StrongComponent>SD 3.5</StrongComponent>: Hỗ trợ cho Từ ngữ Sang Hình Ảnh trong Workflow với phiên bản SD 3.5 Medium hoặc Large.",
"<StrongComponent>Canvas</StrongComponent>: Hợp lý hoá cách xử lý Layer Điều Khiển Được và cải thiện thiết lập điều khiển mặc định."
"<StrongComponent>Workflows</StrongComponent>: Chạy một workflow cho nhiều ảnh bằng node <StrongComponent>Ảnh Hàng Loạt</StrongComponent> mới.",
"<StrongComponent>FLUX</StrongComponent>: Hỗ trợ cho XLabs IP Adapter v2."
]
},
"upsell": {
@@ -2133,5 +2165,67 @@
"inviteTeammates": "Thêm Đồng Đội",
"shareAccess": "Chia Sẻ Quyền Truy Cập",
"professionalUpsell": "Không có sẵn Phiên Bản Chuyên Nghiệp cho Invoke. Bấm vào đây hoặc đến invoke.com/pricing để thêm chi tiết."
},
"supportVideos": {
"supportVideos": "Video Hỗ Trợ",
"gettingStarted": "Bắt Đầu Làm Quen",
"studioSessionsDesc1": "Xem thử <StudioSessionsPlaylistLink /> để hiểu rõ Invoke hơn.",
"studioSessionsDesc2": "Đến <DiscordLink /> để tham gia vào phiên trực tiếp và hỏi câu hỏi. Các phiên được tải lên danh sách phát vào các tuần.",
"videos": {
"howDoIDoImageToImageTransformation": {
"title": "Làm Sao Để Tôi Dùng Trình Biến Đổi Hình Ảnh Sang Hình Ảnh?",
"description": "Hướng dẫn cách thực hiện biến đổi ảnh sang ảnh trong Invoke."
},
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
"description": "Giới thiệu về ảnh mẫu và IP adapter toàn vùng.",
"title": "Làm Sao Để Tôi Dùng IP Adapter Toàn Vùng Và Ảnh Mẫu?"
},
"creatingAndComposingOnInvokesControlCanvas": {
"description": "Học cách sáng tạo ảnh bằng trình điều khiển canvas của Invoke.",
"title": "Sáng Tạo Trong Trình Kiểm Soát Canvas Của Invoke"
},
"upscaling": {
"description": "Cách upscale ảnh bằng bộ công cụ của Invoke để nâng cấp độ phân giải.",
"title": "Upscale (Nâng Cấp Chất Lượng Hình Ảnh)"
},
"howDoIGenerateAndSaveToTheGallery": {
"title": "Làm Sao Để Tôi Tạo Sinh Và Lưu Vào Thư Viện?",
"description": "Các bước để tạo sinh và lưu ảnh vào thư viện."
},
"howDoIEditOnTheCanvas": {
"description": "Hướng dẫn chỉnh sửa ảnh trực tiếp trên canvas.",
"title": "Làm Sao Để Tôi Chỉnh Sửa Trên Canvas?"
},
"howDoIUseControlNetsAndControlLayers": {
"title": "Làm Sao Để Tôi Dùng ControlNet và Layer Điều Khiển Được?",
"description": "Học cách áp dụng layer điều khiển được và controlnet vào ảnh của bạn."
},
"howDoIUseInpaintMasks": {
"title": "Làm Sao Để Tôi Dùng Lớp Phủ Inpaint?",
"description": "Cách áp dụng lớp phủ inpaint vào chỉnh sửa và thay đổi ảnh."
},
"howDoIOutpaint": {
"title": "Làm Sao Để Tôi Outpaint?",
"description": "Hướng dẫn outpaint bên ngoài viền ảnh gốc."
},
"creatingYourFirstImage": {
"description": "Giới thiệu về cách tạo ảnh từ ban đầu bằng công cụ Invoke.",
"title": "Tạo Hình Ảnh Đầu Tiên Của Bạn"
},
"usingControlLayersAndReferenceGuides": {
"description": "Học cách chỉ dẫn ảnh được tạo ra bằng layer điều khiển được và ảnh mẫu.",
"title": "Dùng Layer Điều Khiển Được và Chỉ Dẫn Mẫu"
},
"understandingImageToImageAndDenoising": {
"title": "Hiểu Rõ Trình Hình Ảnh Sang Hình Ảnh Và Trình Khử Nhiễu",
"description": "Tổng quan về trình biến đổi ảnh sang ảnh và trình khử nhiễu trong Invoke."
},
"exploringAIModelsAndConceptAdapters": {
"title": "Khám Phá Model AI Và Khái Niệm Về Adapter",
"description": "Đào sâu vào model AI và cách dùng những adapter để điều khiển một cách sáng tạo."
}
},
"controlCanvas": "Điều Khiển Canvas",
"watch": "Xem"
}
}

@@ -668,7 +668,6 @@
"controlAdapterIncompatibleBaseModel": "Control Adapter的基础模型不兼容",
"ipAdapterIncompatibleBaseModel": "IP Adapter的基础模型不兼容",
"ipAdapterNoImageSelected": "未选择IP Adapter图像",
"rgNoRegion": "未选择区域",
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}},边界框宽度为 {{width}}",
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}},缩放后的边界框高度为 {{height}}",
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}},边界框高度为 {{height}}",

@@ -1,6 +1,6 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { useClearQueue } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
import { useInvoke } from 'features/queue/hooks/useInvoke';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';

@@ -63,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalGuidance}
isDisabled={isFLUX || isSD3}
isDisabled={isSD3}
>
{t('controlLayers.regionalGuidance')}
</Button>

@@ -49,7 +49,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isSD3}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>

@@ -1,27 +1,28 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import type { IconButtonProps } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleFill } from 'react-icons/pi';
import { PiXBold } from 'react-icons/pi';

type Props = {
type Props = Omit<IconButtonProps, 'aria-label'> & {
onDelete: () => void;
};

export const RegionalGuidanceDeletePromptButton = memo(({ onDelete }: Props) => {
export const RegionalGuidanceDeletePromptButton = memo(({ onDelete, ...rest }: Props) => {
const { t } = useTranslation();
return (
<Tooltip label={t('controlLayers.deletePrompt')}>
<IconButton
variant="link"
aria-label={t('controlLayers.deletePrompt')}
icon={<PiTrashSimpleFill />}
onClick={onDelete}
flexGrow={0}
size="sm"
p={0}
colorScheme="error"
/>
</Tooltip>
<IconButton
tooltip={t('common.delete')}
variant="link"
aria-label={t('common.delete')}
icon={<PiXBold />}
onClick={onDelete}
flexGrow={0}
size="sm"
p={0}
colorScheme="error"
{...rest}
/>
);
});
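
The widened `Props` type and the `{...rest}` spread above let callers forward any `IconButton` prop except `aria-label`, which the component sets itself. A hypothetical call site (the wrapper component and its `isBusy` prop are illustrative, not part of this diff) might look like:

// Sketch of a consumer; the import path mirrors the repo layout but is assumed.
import { memo } from 'react';
import { RegionalGuidanceDeletePromptButton } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton';

export const ExamplePromptRow = memo(({ onDelete, isBusy }: { onDelete: () => void; isBusy: boolean }) => (
  // `isDisabled` and `size` pass through the {...rest} spread introduced in this diff.
  <RegionalGuidanceDeletePromptButton onDelete={onDelete} isDisabled={isBusy} size="xs" />
));
ExamplePromptRow.displayName = 'ExamplePromptRow';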

@@ -1,8 +1,9 @@
import { Button, Flex, Text } from '@invoke-ai/ui-library';
import { Button, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { rgIPAdapterDeleted } from 'features/controlLayers/store/canvasSlice';
import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
@@ -10,6 +11,7 @@ import { setRegionalGuidanceReferenceImage } from 'features/imageActions/actions
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';

type Props = {
@@ -31,6 +33,9 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const onDeleteIPAdapter = useCallback(() => {
dispatch(rgIPAdapterDeleted({ entityIdentifier, referenceImageId }));
}, [dispatch, entityIdentifier, referenceImageId]);

const dndTargetData = useMemo<SetRegionalGuidanceReferenceImageDndTargetData>(
() =>
@@ -42,26 +47,44 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
);

return (
<Flex flexDir="column" gap={3} position="relative" w="full" p={4}>
<Text textAlign="center" color="base.300">
<Trans
i18nKey="controlLayers.referenceImageEmptyState"
components={{
UploadButton: (
<Button
isDisabled={isBusy}
size="sm"
variant="link"
color="base.300"
{...uploadApi.getUploadButtonProps()}
/>
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}}
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex alignItems="center" gap={2}>
<Text fontWeight="semibold" color="base.400">
{t('controlLayers.referenceImage')}
</Text>
<Spacer />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
icon={<PiXBold />}
tooltip={t('controlLayers.deleteReferenceImage')}
aria-label={t('controlLayers.deleteReferenceImage')}
onClick={onDeleteIPAdapter}
colorScheme="error"
/>
</Text>
</Flex>
<Flex alignItems="center" gap={2} p={4}>
<Text textAlign="center" color="base.300">
<Trans
i18nKey="controlLayers.referenceImageEmptyState"
components={{
UploadButton: (
<Button
isDisabled={isBusy}
size="sm"
variant="link"
color="base.300"
{...uploadApi.getUploadButtonProps()}
/>
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}}
/>
</Text>
</Flex>
<input {...uploadApi.getUploadInputProps()} />
<DndDropTarget
dndTarget={setRegionalGuidanceReferenceImageDndTarget}

@@ -5,6 +5,7 @@ import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/
import { StagingAreaToolbarImageCountButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarImageCountButton';
import { StagingAreaToolbarNextButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarNextButton';
import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarPrevButton';
import { StagingAreaToolbarSaveAsMenu } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveAsMenu';
import { StagingAreaToolbarSaveSelectedToGalleryButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveSelectedToGalleryButton';
import { StagingAreaToolbarToggleShowResultsButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarToggleShowResultsButton';
import { memo } from 'react';
@@ -21,6 +22,7 @@ export const StagingAreaToolbar = memo(() => {
<StagingAreaToolbarAcceptButton />
<StagingAreaToolbarToggleShowResultsButton />
<StagingAreaToolbarSaveSelectedToGalleryButton />
<StagingAreaToolbarSaveAsMenu />
<StagingAreaToolbarDiscardSelectedButton />
<StagingAreaToolbarDiscardAllButton />
</ButtonGroup>

@@ -0,0 +1,136 @@
import { IconButton, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import { selectSelectedImage } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { createNewCanvasEntityFromImage } from 'features/imageActions/actions';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiDotsThreeBold } from 'react-icons/pi';
import { imageDTOToFile, uploadImage } from 'services/api/endpoints/images';

const uploadImageArg = { image_category: 'general', is_intermediate: true, silent: true } as const;

export const StagingAreaToolbarSaveAsMenu = memo(() => {
const { t } = useTranslation();
const selectedImage = useAppSelector(selectSelectedImage);
const store = useAppStore();

const onClickNewRasterLayerFromImage = useCallback(async () => {
if (!selectedImage) {
return;
}
const { dispatch, getState } = store;
const file = await imageDTOToFile(selectedImage.imageDTO);
const imageDTO = await uploadImage({ file, ...uploadImageArg });
createNewCanvasEntityFromImage({
imageDTO,
type: 'raster_layer',
dispatch,
getState,
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [selectedImage, store, t]);

const onClickNewControlLayerFromImage = useCallback(async () => {
if (!selectedImage) {
return;
}
const { dispatch, getState } = store;
const file = await imageDTOToFile(selectedImage.imageDTO);
const imageDTO = await uploadImage({ file, ...uploadImageArg });
createNewCanvasEntityFromImage({
imageDTO,
type: 'control_layer',
dispatch,
getState,
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [selectedImage, store, t]);

const onClickNewInpaintMaskFromImage = useCallback(async () => {
if (!selectedImage) {
return;
}
const { dispatch, getState } = store;
const file = await imageDTOToFile(selectedImage.imageDTO);
const imageDTO = await uploadImage({ file, ...uploadImageArg });
createNewCanvasEntityFromImage({
imageDTO,
type: 'inpaint_mask',
dispatch,
getState,
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [selectedImage, store, t]);

const onClickNewRegionalGuidanceFromImage = useCallback(async () => {
if (!selectedImage) {
return;
}
const { dispatch, getState } = store;
const file = await imageDTOToFile(selectedImage.imageDTO);
const imageDTO = await uploadImage({ file, ...uploadImageArg });
createNewCanvasEntityFromImage({
imageDTO,
type: 'regional_guidance',
dispatch,
getState,
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [selectedImage, store, t]);

return (
<Menu>
<MenuButton
as={IconButton}
aria-label={t('controlLayers.newLayerFromImage')}
tooltip={t('controlLayers.newLayerFromImage')}
icon={<PiDotsThreeBold />}
colorScheme="invokeBlue"
isDisabled={!selectedImage}
/>
<MenuList>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewInpaintMaskFromImage} isDisabled={!selectedImage}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewRegionalGuidanceFromImage}
isDisabled={!selectedImage}
>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewControlLayerFromImage} isDisabled={!selectedImage}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewRasterLayerFromImage} isDisabled={!selectedImage}>
{t('controlLayers.rasterLayer')}
</MenuItem>
</MenuList>
</Menu>
);
});

StagingAreaToolbarSaveAsMenu.displayName = 'StagingAreaToolbarSaveAsMenu';
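
The four callbacks above are identical except for the entity `type` they hand to `createNewCanvasEntityFromImage`. A possible consolidation (a sketch only, not part of this diff; `useSaveStagedImageAs` is a hypothetical name) would hoist the shared re-upload/dispatch/toast sequence into one parameterized hook:

// Sketch: factors the repeated logic out of the four useCallback handlers above.
// Assumes the same imports as the file in this diff; CanvasEntityType comes from
// features/controlLayers/store/types.
const useSaveStagedImageAs = (type: CanvasEntityType) => {
  const { t } = useTranslation();
  const selectedImage = useAppSelector(selectSelectedImage);
  const store = useAppStore();
  return useCallback(async () => {
    if (!selectedImage) {
      return;
    }
    const { dispatch, getState } = store;
    // Re-upload a copy so the staged image and the new layer do not share a lifetime.
    const file = await imageDTOToFile(selectedImage.imageDTO);
    const imageDTO = await uploadImage({ file, ...uploadImageArg });
    createNewCanvasEntityFromImage({
      imageDTO,
      type,
      dispatch,
      getState,
      overrides: { isEnabled: false }, // layers added while staging start disabled
    });
    toast({ id: 'SENT_TO_CANVAS', title: t('toast.sentToCanvas'), status: 'success' });
  }, [selectedImage, store, t, type]);
};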

@@ -1,6 +1,4 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $authToken } from 'app/store/nanostores/authToken';
import { useAppSelector } from 'app/store/storeHooks';
import { withResultAsync } from 'common/util/result';
import { selectSelectedImage } from 'features/controlLayers/store/canvasStagingAreaSlice';
@@ -9,14 +7,13 @@ import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold } from 'react-icons/pi';
import { uploadImage } from 'services/api/endpoints/images';
import { imageDTOToFile, uploadImage } from 'services/api/endpoints/images';

const TOAST_ID = 'SAVE_STAGING_AREA_IMAGE_TO_GALLERY';

export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const selectedImage = useAppSelector(selectSelectedImage);
const authToken = useStore($authToken);

const { t } = useTranslation();

@@ -28,18 +25,8 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
// To save the image to gallery, we will download it and re-upload it. This allows the user to delete the image
// from the gallery without borking the canvas, which may need this image to exist.
const result = await withResultAsync(async () => {
// Download the image
const requestOpts = authToken
? {
headers: {
Authorization: `Bearer ${authToken}`,
},
}
: {};
const res = await fetch(selectedImage.imageDTO.image_url, requestOpts);
const blob = await res.blob();
// Create a new file with the same name, which we will upload
const file = new File([blob], `copy_of_${selectedImage.imageDTO.image_name}`, { type: 'image/png' });
const file = await imageDTOToFile(selectedImage.imageDTO);

await uploadImage({
file,
@@ -66,7 +53,7 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
status: 'error',
});
}
}, [autoAddBoardId, selectedImage, t, authToken]);
}, [autoAddBoardId, selectedImage, t]);

return (
<IconButton
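
The hunks above swap the inline fetch-and-wrap logic for a shared `imageDTOToFile` helper, which also removes the component's direct dependency on `$authToken`. The helper's body is not shown in this diff; judging from the code it replaces, it plausibly looks something like this (a sketch under that assumption; the real implementation in services/api/endpoints/images may differ):

// Hypothetical reconstruction of imageDTOToFile based on the inline code it replaces.
import { $authToken } from 'app/store/nanostores/authToken';
import type { ImageDTO } from 'services/api/types';

export const imageDTOToFile = async (imageDTO: ImageDTO): Promise<File> => {
  const authToken = $authToken.get();
  const requestOpts = authToken ? { headers: { Authorization: `Bearer ${authToken}` } } : {};
  const res = await fetch(imageDTO.image_url, requestOpts);
  const blob = await res.blob();
  // Copy under a new name so deleting the gallery copy cannot bork the canvas original.
  return new File([blob], `copy_of_${imageDTO.image_name}`, { type: 'image/png' });
};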

@@ -1,6 +1,7 @@
import { Flex } from '@invoke-ai/ui-library';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
import { CanvasEntityHeaderWarnings } from 'features/controlLayers/components/common/CanvasEntityHeaderWarnings';
import { CanvasEntityIsBookmarkedForQuickSwitchToggle } from 'features/controlLayers/components/common/CanvasEntityIsBookmarkedForQuickSwitchToggle';
import { CanvasEntityIsLockedToggle } from 'features/controlLayers/components/common/CanvasEntityIsLockedToggle';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -11,6 +12,7 @@ export const CanvasEntityHeaderCommonActions = memo(() => {

return (
<Flex alignSelf="stretch">
<CanvasEntityHeaderWarnings />
<CanvasEntityIsBookmarkedForQuickSwitchToggle />
{entityIdentifier.type !== 'reference_image' && <CanvasEntityIsLockedToggle />}
<CanvasEntityEnabledToggle />

@@ -0,0 +1,101 @@
import { Flex, IconButton, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityIsEnabled } from 'features/controlLayers/hooks/useEntityIsEnabled';
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import {
getControlLayerWarnings,
getGlobalReferenceImageWarnings,
getInpaintMaskWarnings,
getRasterLayerWarnings,
getRegionalGuidanceWarnings,
} from 'features/controlLayers/store/validators';
import type { TFunction } from 'i18next';
import { upperFirst } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiWarningBold } from 'react-icons/pi';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';

const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunction) => {
return createSelector(selectCanvasSlice, selectModel, (canvas, model) => {
// This component is used within a <CanvasEntityStateGate /> so we can safely assume that the entity exists.
// Should never throw.
const entity = selectEntityOrThrow(canvas, entityIdentifier, 'CanvasEntityHeaderWarnings');

let warnings: string[] = [];

const entityType = entity.type;

if (entityType === 'control_layer') {
warnings = getControlLayerWarnings(entity, model);
} else if (entityType === 'regional_guidance') {
warnings = getRegionalGuidanceWarnings(entity, model);
} else if (entityType === 'inpaint_mask') {
warnings = getInpaintMaskWarnings(entity, model);
} else if (entityType === 'raster_layer') {
warnings = getRasterLayerWarnings(entity, model);
} else if (entityType === 'reference_image') {
warnings = getGlobalReferenceImageWarnings(entity, model);
} else {
assert<Equals<typeof entityType, never>>(false, 'Unexpected entity type');
}

// Return a stable reference if there are no warnings
if (warnings.length === 0) {
return EMPTY_ARRAY;
}

return warnings.map((w) => t(w)).map(upperFirst);
});
};

export const CanvasEntityHeaderWarnings = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const { t } = useTranslation();
const isEnabled = useEntityIsEnabled(entityIdentifier);
const selectWarnings = useMemo(() => buildSelectWarnings(entityIdentifier, t), [entityIdentifier, t]);
const warnings = useAppSelector(selectWarnings);

if (warnings.length === 0) {
return null;
}

return (
// Using IconButton here because it matches the styling of the actual buttons in the header without any
// finagling, but it's not a button
<IconButton
as="span"
size="sm"
variant="link"
alignSelf="stretch"
aria-label="warnings"
tooltip={<TooltipContent warnings={warnings} />}
icon={<PiWarningBold />}
colorScheme="warning"
isDisabled={!isEnabled}
/>
);
});

CanvasEntityHeaderWarnings.displayName = 'CanvasEntityHeaderWarnings';

const TooltipContent = memo((props: { warnings: string[] }) => {
const { t } = useTranslation();
return (
<Flex flexDir="column">
<Text>{t('controlLayers.warnings.problemsFound')}:</Text>
<UnorderedList>
{props.warnings.map((warning, index) => (
<ListItem key={index}>{warning}</ListItem>
))}
</UnorderedList>
</Flex>
);
});
TooltipContent.displayName = 'TooltipContent';
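
The `EMPTY_ARRAY` early return above is what makes the memoization useful: `useAppSelector` re-renders on reference inequality, and a fresh `[]` built on every selector run would never compare equal to the previous result. A minimal standalone sketch of the difference (names illustrative):

// Sketch: why a shared constant beats a fresh literal as a selector's "empty" result.
const EMPTY_ARRAY: readonly string[] = [];

type State = { items: string[] };

// Builds a new array on every call, so results never compare reference-equal.
const selectFresh = (state: State) => state.items.filter((item) => item.length > 0);

// Reuses one shared reference whenever the result is empty.
const selectStable = (state: State) => (state.items.length === 0 ? EMPTY_ARRAY : state.items);

const state: State = { items: [] };
console.log(selectFresh(state) === selectFresh(state)); // false -> spurious re-renders
console.log(selectStable(state) === selectStable(state)); // true -> render skipped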

@@ -29,7 +29,13 @@ import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';

/** @knipignore */
/**
* Selects the default control adapter configuration based on the model configurations and the base.
*
* Be sure to clone the output of this selector before modifying it!
*
* @knipignore
*/
export const selectDefaultControlAdapter = createSelector(
selectModelConfigsQuery,
selectBase,
@@ -52,6 +58,11 @@
}
);

/**
* Selects the default IP adapter configuration based on the model configurations and the base.
*
* Be sure to clone the output of this selector before modifying it!
*/
export const selectDefaultIPAdapter = createSelector(
selectModelConfigsQuery,
selectBase,
@@ -117,7 +128,9 @@ export const useAddRegionalReferenceImage = () => {

const func = useCallback(() => {
const overrides: Partial<CanvasRegionalGuidanceState> = {
referenceImages: [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter: defaultIPAdapter }],
referenceImages: [
{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter: deepClone(defaultIPAdapter) },
],
};
dispatch(rgAdded({ isSelected: true, overrides }));
}, [defaultIPAdapter, dispatch]);
@@ -129,7 +142,7 @@ export const useAddGlobalReferenceImage = () => {
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const func = useCallback(() => {
const overrides = { ipAdapter: defaultIPAdapter };
const overrides = { ipAdapter: deepClone(defaultIPAdapter) };
dispatch(referenceImageAdded({ isSelected: true, overrides }));
}, [defaultIPAdapter, dispatch]);

@@ -140,7 +153,7 @@ export const useAddRegionalGuidanceIPAdapter = (entityIdentifier: CanvasEntityId
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const func = useCallback(() => {
dispatch(rgIPAdapterAdded({ entityIdentifier, overrides: { ipAdapter: defaultIPAdapter } }));
dispatch(rgIPAdapterAdded({ entityIdentifier, overrides: { ipAdapter: deepClone(defaultIPAdapter) } }));
}, [defaultIPAdapter, dispatch, entityIdentifier]);

return func;
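
These `deepClone` additions are the payoff of the "clone the output of this selector before modifying it" warning in the doc comments above: `createSelector` memoizes by reference, so every caller receives the same cached object, and mutating it in place would silently rewrite the shared default for all later callers. A minimal standalone sketch of the failure mode (`structuredClone` stands in for the repo's deepClone utility):

// Sketch: mutating a memoized selector's cached result corrupts it for every caller.
type IPAdapterLike = { model: string | null; image: string | null };

const cached: IPAdapterLike = { model: 'ip_adapter_sd15', image: null };
const selectDefault = (): IPAdapterLike => cached; // stands in for createSelector's cache

const bad = selectDefault();
bad.image = 'layer-1.png'; // oops: the shared "default" now carries layer-1's image

const good = structuredClone(selectDefault()); // the diff uses deepClone for the same effect
good.image = 'layer-2.png'; // safe: the cached default is untouched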

@@ -0,0 +1,153 @@
import type {
CanvasControlLayerState,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';

const WARNINGS = {
UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel',
RG_NO_PROMPTS_OR_IP_ADAPTERS: 'controlLayers.warnings.rgNoPromptsOrIPAdapters',
RG_NEGATIVE_PROMPT_NOT_SUPPORTED: 'controlLayers.warnings.rgNegativePromptNotSupported',
RG_REFERENCE_IMAGES_NOT_SUPPORTED: 'controlLayers.warnings.rgReferenceImagesNotSupported',
RG_AUTO_NEGATIVE_NOT_SUPPORTED: 'controlLayers.warnings.rgAutoNegativeNotSupported',
RG_NO_REGION: 'controlLayers.warnings.rgNoRegion',
IP_ADAPTER_NO_MODEL_SELECTED: 'controlLayers.warnings.ipAdapterNoModelSelected',
IP_ADAPTER_INCOMPATIBLE_BASE_MODEL: 'controlLayers.warnings.ipAdapterIncompatibleBaseModel',
IP_ADAPTER_NO_IMAGE_SELECTED: 'controlLayers.warnings.ipAdapterNoImageSelected',
CONTROL_ADAPTER_NO_MODEL_SELECTED: 'controlLayers.warnings.controlAdapterNoModelSelected',
CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL: 'controlLayers.warnings.controlAdapterIncompatibleBaseModel',
CONTROL_ADAPTER_NO_CONTROL: 'controlLayers.warnings.controlAdapterNoControl',
} as const;

type WarningTKey = (typeof WARNINGS)[keyof typeof WARNINGS];

export const getRegionalGuidanceWarnings = (
entity: CanvasRegionalGuidanceState,
model: ParameterModel | null
): WarningTKey[] => {
const warnings: WarningTKey[] = [];

if (entity.objects.length === 0) {
// Layer is in empty state
warnings.push(WARNINGS.RG_NO_REGION);
}

if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) {
// Must have at least 1 prompt or IP Adapter
warnings.push(WARNINGS.RG_NO_PROMPTS_OR_IP_ADAPTERS);
}

if (model) {
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
} else if (model.base === 'flux') {
// Some features are not supported for flux models
if (entity.negativePrompt !== null) {
warnings.push(WARNINGS.RG_NEGATIVE_PROMPT_NOT_SUPPORTED);
}
if (entity.referenceImages.length > 0) {
warnings.push(WARNINGS.RG_REFERENCE_IMAGES_NOT_SUPPORTED);
}
if (entity.autoNegative) {
warnings.push(WARNINGS.RG_AUTO_NEGATIVE_NOT_SUPPORTED);
}
} else {
entity.referenceImages.forEach(({ ipAdapter }) => {
if (!ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}

if (!ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}
});
}
}

return warnings;
};

export const getGlobalReferenceImageWarnings = (
entity: CanvasReferenceImageState,
model: ParameterModel | null
): WarningTKey[] => {
const warnings: WarningTKey[] = [];

if (!entity.ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (model) {
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
} else if (entity.ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}
}

if (!entity.ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}

return warnings;
};

export const getControlLayerWarnings = (
entity: CanvasControlLayerState,
model: ParameterModel | null
): WarningTKey[] => {
const warnings: WarningTKey[] = [];

if (entity.objects.length === 0) {
// Layer is in empty state
warnings.push(WARNINGS.CONTROL_ADAPTER_NO_CONTROL);
}

if (!entity.controlAdapter.model) {
// No model selected
warnings.push(WARNINGS.CONTROL_ADAPTER_NO_MODEL_SELECTED);
} else if (model) {
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
} else if (entity.controlAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}
}

return warnings;
};

export const getRasterLayerWarnings = (
_entity: CanvasRasterLayerState,
_model: ParameterModel | null
): WarningTKey[] => {
const warnings: WarningTKey[] = [];

// There are no warnings at the moment for raster layers.

return warnings;
};

export const getInpaintMaskWarnings = (
_entity: CanvasInpaintMaskState,
_model: ParameterModel | null
): WarningTKey[] => {
const warnings: WarningTKey[] = [];

// There are no warnings at the moment for inpaint masks.

return warnings;
};

@@ -77,6 +77,32 @@ export const ImageMenuItemNewLayerFromImageSubMenu = memo(() => {
});
}, [imageDTO, imageViewer, store, t]);

const onClickNewRegionalReferenceImageFromImage = useCallback(() => {
const { dispatch, getState } = store;
createNewCanvasEntityFromImage({ imageDTO, type: 'regional_guidance_with_reference_image', dispatch, getState });
dispatch(sentImageToCanvas());
dispatch(setActiveTab('canvas'));
imageViewer.close();
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [imageDTO, imageViewer, store, t]);

const onClickNewGlobalReferenceImageFromImage = useCallback(() => {
const { dispatch, getState } = store;
createNewCanvasEntityFromImage({ imageDTO, type: 'reference_image', dispatch, getState });
dispatch(sentImageToCanvas());
dispatch(setActiveTab('canvas'));
imageViewer.close();
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [imageDTO, imageViewer, store, t]);

return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiPlusBold />}>
<Menu {...subMenu.menuProps}>
@@ -104,6 +130,20 @@ export const ImageMenuItemNewLayerFromImageSubMenu = memo(() => {
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewRasterLayerFromImage} isDisabled={isBusy}>
{t('controlLayers.rasterLayer')}
</MenuItem>
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewRegionalReferenceImageFromImage}
isDisabled={isBusy}
>
{t('controlLayers.referenceImageRegional')}
</MenuItem>
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewGlobalReferenceImageFromImage}
isDisabled={isBusy}
>
{t('controlLayers.referenceImageGlobal')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>

@@ -20,6 +20,7 @@ import { selectBboxModelBase, selectBboxRect } from 'features/controlLayers/stor
import type {
  CanvasControlLayerState,
  CanvasEntityIdentifier,
  CanvasEntityState,
  CanvasEntityType,
  CanvasInpaintMaskState,
  CanvasRasterLayerState,
@@ -134,14 +135,16 @@ export const createNewCanvasEntityFromImage = (arg: {
  type: CanvasEntityType | 'regional_guidance_with_reference_image';
  dispatch: AppDispatch;
  getState: () => RootState;
  overrides?: Partial<Pick<CanvasEntityState, 'isEnabled' | 'isLocked' | 'name'>>;
}) => {
  const { type, imageDTO, dispatch, getState } = arg;
  const { type, imageDTO, dispatch, getState, overrides: _overrides } = arg;
  const state = getState();
  const imageObject = imageDTOToImageObject(imageDTO);
  const { x, y } = selectBboxRect(state);
  const overrides = {
    objects: [imageObject],
    position: { x, y },
    ..._overrides,
  };
  switch (type) {
    case 'raster_layer': {
@@ -166,13 +169,13 @@ export const createNewCanvasEntityFromImage = (arg: {
      break;
    }
    case 'reference_image': {
      const ipAdapter = selectDefaultIPAdapter(getState());
      const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
      ipAdapter.image = imageDTOToImageWithDims(imageDTO);
      dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));
      break;
    }
    case 'regional_guidance_with_reference_image': {
      const ipAdapter = selectDefaultIPAdapter(getState());
      const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
      ipAdapter.image = imageDTOToImageWithDims(imageDTO);
      const referenceImages = [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }];
      dispatch(rgAdded({ overrides: { referenceImages }, isSelected: true }));
@@ -288,14 +291,14 @@ export const newCanvasFromImage = (arg: {
      break;
    }
    case 'reference_image': {
      const ipAdapter = selectDefaultIPAdapter(getState());
      const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
      ipAdapter.image = imageDTOToImageWithDims(imageDTO);
      dispatch(canvasReset());
      dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));
      break;
    }
    case 'regional_guidance_with_reference_image': {
      const ipAdapter = selectDefaultIPAdapter(getState());
      const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
      ipAdapter.image = imageDTOToImageWithDims(imageDTO);
      const referenceImages = [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }];
      dispatch(canvasReset());

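Why the deepClone calls in the hunks above matter: selectDefaultIPAdapter returns an object held by a memoized selector, so assigning ipAdapter.image on it directly would mutate state shared by every later caller. A minimal sketch of the aliasing hazard, using stand-in names rather than the real selector:

// `defaultConfig` stands in for the object a memoized selector returns;
// every caller receives the same reference.
type IPAdapterLike = { model: string | null; image: string | null };

const defaultConfig: IPAdapterLike = { model: 'ip_adapter_sd1', image: null };
const selectDefault = (): IPAdapterLike => defaultConfig;

// Buggy: mutates the shared default for every future caller.
const buggy = selectDefault();
buggy.image = 'image_1.png';

// Safe: copy first, then mutate the private copy. The app's deepClone util
// handles nested state; a shallow spread is enough for this flat sketch.
const safe = { ...selectDefault() };
safe.image = 'image_2.png';

console.log(defaultConfig.image); // 'image_1.png' — the shared default was corrupted by the buggy path
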
@@ -1,35 +1,41 @@
import { logger } from 'app/logging/logger';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type {
  CanvasControlLayerState,
  ControlNetConfig,
  Rect,
  T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store/types';
import { getControlLayerWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import type { ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe';

const log = logger('system');

type AddControlNetsArg = {
  manager: CanvasManager;
  entities: CanvasControlLayerState[];
  g: Graph;
  rect: Rect;
  collector: Invocation<'collect'>;
  model: ParameterModel;
};

type AddControlNetsResult = {
  addedControlNets: number;
};

export const addControlNets = async (
  manager: CanvasManager,
  layers: CanvasControlLayerState[],
  g: Graph,
  rect: Rect,
  collector: Invocation<'collect'>,
  base: BaseModelType
): Promise<AddControlNetsResult> => {
  const validControlLayers = layers
    .filter((layer) => layer.isEnabled)
    .filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
    .filter((layer) => layer.controlAdapter.type === 'controlnet');
export const addControlNets = async ({
  manager,
  entities,
  g,
  rect,
  collector,
  model,
}: AddControlNetsArg): Promise<AddControlNetsResult> => {
  const validControlLayers = entities
    .filter((entity) => entity.isEnabled)
    .filter((entity) => entity.controlAdapter.type === 'controlnet')
    .filter((entity) => getControlLayerWarnings(entity, model).length === 0);

  const result: AddControlNetsResult = {
    addedControlNets: 0,
@@ -54,22 +60,31 @@ export const addControlNets = async (
  return result;
};

type AddT2IAdaptersArg = {
  manager: CanvasManager;
  entities: CanvasControlLayerState[];
  g: Graph;
  rect: Rect;
  collector: Invocation<'collect'>;
  model: ParameterModel;
};

type AddT2IAdaptersResult = {
  addedT2IAdapters: number;
};

export const addT2IAdapters = async (
  manager: CanvasManager,
  layers: CanvasControlLayerState[],
  g: Graph,
  rect: Rect,
  collector: Invocation<'collect'>,
  base: BaseModelType
): Promise<AddT2IAdaptersResult> => {
  const validControlLayers = layers
    .filter((layer) => layer.isEnabled)
    .filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
    .filter((layer) => layer.controlAdapter.type === 't2i_adapter');
export const addT2IAdapters = async ({
  manager,
  entities,
  g,
  rect,
  collector,
  model,
}: AddT2IAdaptersArg): Promise<AddT2IAdaptersResult> => {
  const validControlLayers = entities
    .filter((entity) => entity.isEnabled)
    .filter((entity) => entity.controlAdapter.type === 't2i_adapter')
    .filter((entity) => getControlLayerWarnings(entity, model).length === 0);

  const result: AddT2IAdaptersResult = {
    addedT2IAdapters: 0,
@@ -145,11 +160,3 @@ const addT2IAdapterToGraph = (

  g.addEdge(t2iAdapter, 't2i_adapter', collector, 'item');
};

const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => {
  // Must have a model
  const hasModel = Boolean(controlAdapter.model);
  // Model must match the current base model
  const modelMatchesBase = controlAdapter.model?.base === base;
  return hasModel && modelMatchesBase;
};

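Both helpers above trade positional parameters for a single typed options object. A sketch of the pattern under simplified types; the gain is that call sites become self-documenting and new optional fields (like the model field added in this changeset) do not disturb existing callers:

type AddThingsArg = {
  entities: { isEnabled: boolean }[];
  label: string;
};

// Named fields make the call site readable and order-independent.
const addThings = ({ entities, label }: AddThingsArg): string => {
  const added = entities.filter((e) => e.isEnabled).length;
  return `${label}: ${added}`;
};

console.log(addThings({ label: 'controlnets', entities: [{ isEnabled: true }, { isEnabled: false }] })); // 'controlnets: 1'
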
@@ -1,19 +1,23 @@
import type { CanvasReferenceImageState } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, Invocation } from 'services/api/types';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';

type AddIPAdaptersResult = {
  addedIPAdapters: number;
};

export const addIPAdapters = (
  ipAdapters: CanvasReferenceImageState[],
  g: Graph,
  collector: Invocation<'collect'>,
  base: BaseModelType
): AddIPAdaptersResult => {
  const validIPAdapters = ipAdapters.filter((entity) => isValidIPAdapter(entity, base));
type AddIPAdaptersArg = {
  entities: CanvasReferenceImageState[];
  g: Graph;
  collector: Invocation<'collect'>;
  model: ParameterModel;
};

export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => {
  const validIPAdapters = entities.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);

  const result: AddIPAdaptersResult = {
    addedIPAdapters: 0,
@@ -76,11 +80,3 @@ const addIPAdapter = (entity: CanvasReferenceImageState, g: Graph, collector: In

  g.addEdge(ipAdapterNode, 'ip_adapter', collector, 'item');
};

const isValidIPAdapter = ({ isEnabled, ipAdapter }: CanvasReferenceImageState, base: BaseModelType): boolean => {
  // Must have a model that matches the current base and must have a control image
  const hasModel = Boolean(ipAdapter.model);
  const modelMatchesBase = ipAdapter.model?.base === base;
  const hasImage = Boolean(ipAdapter.image);
  return isEnabled && hasModel && modelMatchesBase && hasImage;
};

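The boolean isValidIPAdapter predicate above is replaced by getGlobalReferenceImageWarnings(entity, model).length === 0, so graph building and the "why can't I invoke" UI share one validation source. A hedged sketch of that warnings-style shape — not the real implementation, though the i18n keys are the ones used elsewhere in this changeset:

type ModelRef = { base: string } | null;
type RefImageLike = { isEnabled: boolean; model: ModelRef; image: string | null };

// Instead of a yes/no answer, return a list of i18n keys: the UI can render
// the reasons, and graph builders can simply filter on `warnings.length === 0`.
const getReferenceImageWarnings = (entity: RefImageLike, mainModel: ModelRef): string[] => {
  const warnings: string[] = [];
  if (!entity.model) {
    warnings.push('parameters.invoke.layer.ipAdapterNoModelSelected');
  } else if (mainModel && entity.model.base !== mainModel.base) {
    warnings.push('parameters.invoke.layer.ipAdapterIncompatibleBaseModel');
  }
  if (!entity.image) {
    warnings.push('parameters.invoke.layer.ipAdapterNoImageSelected');
  }
  return warnings;
};

const entity: RefImageLike = { isEnabled: true, model: { base: 'sd-1' }, image: null };
console.log(getReferenceImageWarnings(entity, { base: 'sdxl' }));
// ['parameters.invoke.layer.ipAdapterIncompatibleBaseModel', 'parameters.invoke.layer.ipAdapterNoImageSelected']
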
@@ -3,15 +3,12 @@ import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
  CanvasRegionalGuidanceState,
  IPAdapterConfig,
  Rect,
  RegionalGuidanceReferenceImageState,
} from 'features/controlLayers/store/types';
import type { CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types';
import { getRegionalGuidanceWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
import type { BaseModelType, Invocation } from 'services/api/types';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';

const log = logger('system');
@@ -23,19 +20,26 @@ type AddedRegionResult = {
  addedIPAdapters: number;
};

const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => {
  const isEnabled = rg.isEnabled;
  const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt);
  const hasIPAdapter = rg.referenceImages.filter(({ ipAdapter }) => isValidIPAdapter(ipAdapter, base)).length > 0;
  return isEnabled && (hasTextPrompt || hasIPAdapter);
type AddRegionsArg = {
  manager: CanvasManager;
  regions: CanvasRegionalGuidanceState[];
  g: Graph;
  bbox: Rect;
  model: ParameterModel;
  posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
  negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
  posCondCollect: Invocation<'collect'>;
  negCondCollect: Invocation<'collect'> | null;
  ipAdapterCollect: Invocation<'collect'>;
};

/**
 * Adds regional guidance to the graph
 * @param manager The canvas manager
 * @param regions Array of regions to add
 * @param g The graph to add the layers to
 * @param base The base model type
 * @param denoise The main denoise node
 * @param bbox The bounding box
 * @param model The main model
 * @param posCond The positive conditioning node
 * @param negCond The negative conditioning node
 * @param posCondCollect The positive conditioning collector
@@ -44,22 +48,28 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) =>
 * @returns A promise that resolves to the regions that were successfully added to the graph
 */

export const addRegions = async (
  manager: CanvasManager,
  regions: CanvasRegionalGuidanceState[],
  g: Graph,
  bbox: Rect,
  base: BaseModelType,
  denoise: Invocation<'denoise_latents'>,
  posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
  negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
  posCondCollect: Invocation<'collect'>,
  negCondCollect: Invocation<'collect'>,
  ipAdapterCollect: Invocation<'collect'>
): Promise<AddedRegionResult[]> => {
  const isSDXL = base === 'sdxl';
export const addRegions = async ({
  manager,
  regions,
  g,
  bbox,
  model,
  posCond,
  negCond,
  posCondCollect,
  negCondCollect,
  ipAdapterCollect,
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
  const isSDXL = model.base === 'sdxl';
  const isFLUX = model.base === 'flux';

  const validRegions = regions.filter((rg) => {
    if (!rg.isEnabled) {
      return false;
    }
    return getRegionalGuidanceWarnings(rg, model).length === 0;
  });

  const validRegions = regions.filter((rg) => isValidRegion(rg, base));
  const results: AddedRegionResult[] = [];

  for (const region of validRegions) {
@@ -94,20 +104,27 @@ export const addRegions = async (
    if (region.positivePrompt) {
      // The main positive conditioning node
      result.addedPositivePrompt = true;
      const regionalPosCond = g.addNode(
        isSDXL
          ? {
              type: 'sdxl_compel_prompt',
              id: getPrefixedId('prompt_region_positive_cond'),
              prompt: region.positivePrompt,
              style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
            }
          : {
              type: 'compel',
              id: getPrefixedId('prompt_region_positive_cond'),
              prompt: region.positivePrompt,
            }
      );
      let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
      if (isSDXL) {
        regionalPosCond = g.addNode({
          type: 'sdxl_compel_prompt',
          id: getPrefixedId('prompt_region_positive_cond'),
          prompt: region.positivePrompt,
          style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
        });
      } else if (isFLUX) {
        regionalPosCond = g.addNode({
          type: 'flux_text_encoder',
          id: getPrefixedId('prompt_region_positive_cond'),
          prompt: region.positivePrompt,
        });
      } else {
        regionalPosCond = g.addNode({
          type: 'compel',
          id: getPrefixedId('prompt_region_positive_cond'),
          prompt: region.positivePrompt,
        });
      }
      // Connect the mask to the conditioning
      g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask');
      // Connect the conditioning to the collector
@@ -115,38 +132,55 @@ export const addRegions = async (
      // Copy the connections to the "global" positive conditioning node to the regional cond
      if (posCond.type === 'compel') {
        for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
          // Clone the edge, but change the destination node to the regional conditioning node
          const clone = deepClone(edge);
          clone.destination.node_id = regionalPosCond.id;
          g.addEdgeFromObj(clone);
        }
      } else if (posCond.type === 'sdxl_compel_prompt') {
        for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
          const clone = deepClone(edge);
          clone.destination.node_id = regionalPosCond.id;
          g.addEdgeFromObj(clone);
        }
      } else if (posCond.type === 'flux_text_encoder') {
        for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
          const clone = deepClone(edge);
          clone.destination.node_id = regionalPosCond.id;
          g.addEdgeFromObj(clone);
        }
      } else {
        for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
          // Clone the edge, but change the destination node to the regional conditioning node
          const clone = deepClone(edge);
          clone.destination.node_id = regionalPosCond.id;
          g.addEdgeFromObj(clone);
        }
        assert(false, 'Unsupported positive conditioning node type.');
      }
    }

    if (region.negativePrompt) {
      result.addedNegativePrompt = true;
      assert(negCond, 'Negative conditioning node is required if there is a negative prompt');
      assert(negCondCollect, 'Negative conditioning collector is required if there is a negative prompt');

      // The main negative conditioning node
      const regionalNegCond = g.addNode(
        isSDXL
          ? {
              type: 'sdxl_compel_prompt',
              id: getPrefixedId('prompt_region_negative_cond'),
              prompt: region.negativePrompt,
              style: region.negativePrompt,
            }
          : {
              type: 'compel',
              id: getPrefixedId('prompt_region_negative_cond'),
              prompt: region.negativePrompt,
            }
      );
      result.addedNegativePrompt = true;
      let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
      if (isSDXL) {
        regionalNegCond = g.addNode({
          type: 'sdxl_compel_prompt',
          id: getPrefixedId('prompt_region_negative_cond'),
          prompt: region.negativePrompt,
          style: region.negativePrompt,
        });
      } else if (isFLUX) {
        regionalNegCond = g.addNode({
          type: 'flux_text_encoder',
          id: getPrefixedId('prompt_region_negative_cond'),
          prompt: region.negativePrompt,
        });
      } else {
        regionalNegCond = g.addNode({
          type: 'compel',
          id: getPrefixedId('prompt_region_negative_cond'),
          prompt: region.negativePrompt,
        });
      }

      // Connect the mask to the conditioning
      g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask');
      // Connect the conditioning to the collector
@@ -158,17 +192,27 @@ export const addRegions = async (
          clone.destination.node_id = regionalNegCond.id;
          g.addEdgeFromObj(clone);
        }
      } else {
      } else if (negCond.type === 'sdxl_compel_prompt') {
        for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) {
          const clone = deepClone(edge);
          clone.destination.node_id = regionalNegCond.id;
          g.addEdgeFromObj(clone);
        }
      } else if (negCond.type === 'flux_text_encoder') {
        for (const edge of g.getEdgesTo(negCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
          const clone = deepClone(edge);
          clone.destination.node_id = regionalNegCond.id;
          g.addEdgeFromObj(clone);
        }
      } else {
        assert(false, 'Unsupported negative conditioning node type.');
      }
    }

    // If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
    if (region.autoNegative && region.positivePrompt) {
      assert(negCondCollect, 'Negative conditioning collector is required if there is an auto-negative setting');

      result.addedAutoNegativePositivePrompt = true;
      // We re-use the mask image, but invert it when converting to tensor
      const invertTensorMask = g.addNode({
@@ -178,20 +222,27 @@ export const addRegions = async (
      // Connect the OG mask image to the inverted mask-to-tensor node
      g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask');
      // Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt
      const regionalPosCondInverted = g.addNode(
        isSDXL
          ? {
              type: 'sdxl_compel_prompt',
              id: getPrefixedId('prompt_region_positive_cond_inverted'),
              prompt: region.positivePrompt,
              style: region.positivePrompt,
            }
          : {
              type: 'compel',
              id: getPrefixedId('prompt_region_positive_cond_inverted'),
              prompt: region.positivePrompt,
            }
      );
      let regionalPosCondInverted: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
      if (isSDXL) {
        regionalPosCondInverted = g.addNode({
          type: 'sdxl_compel_prompt',
          id: getPrefixedId('prompt_region_positive_cond_inverted'),
          prompt: region.positivePrompt,
          style: region.positivePrompt,
        });
      } else if (isFLUX) {
        regionalPosCondInverted = g.addNode({
          type: 'flux_text_encoder',
          id: getPrefixedId('prompt_region_positive_cond_inverted'),
          prompt: region.positivePrompt,
        });
      } else {
        regionalPosCondInverted = g.addNode({
          type: 'compel',
          id: getPrefixedId('prompt_region_positive_cond_inverted'),
          prompt: region.positivePrompt,
        });
      }
      // Connect the inverted mask to the conditioning
      g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask');
      // Connect the conditioning to the negative collector
@@ -203,20 +254,26 @@ export const addRegions = async (
          clone.destination.node_id = regionalPosCondInverted.id;
          g.addEdgeFromObj(clone);
        }
      } else {
      } else if (posCond.type === 'sdxl_compel_prompt') {
        for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
          const clone = deepClone(edge);
          clone.destination.node_id = regionalPosCondInverted.id;
          g.addEdgeFromObj(clone);
        }
      } else if (posCond.type === 'flux_text_encoder') {
        for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
          const clone = deepClone(edge);
          clone.destination.node_id = regionalPosCondInverted.id;
          g.addEdgeFromObj(clone);
        }
      } else {
        assert(false, 'Unsupported positive conditioning node type.');
      }
    }

    const validRGIPAdapters: RegionalGuidanceReferenceImageState[] = region.referenceImages.filter(({ ipAdapter }) =>
      isValidIPAdapter(ipAdapter, base)
    );
    for (const { id, ipAdapter } of region.referenceImages) {
      assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');

    for (const { id, ipAdapter } of validRGIPAdapters) {
      result.addedIPAdapters++;
      const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
      assert(model, 'IP Adapter model is required');
@@ -248,11 +305,3 @@ export const addRegions = async (

  return results;
};

const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => {
  // Must have a model that matches the current base and must have a control image
  const hasModel = Boolean(ipAdapter.model);
  const modelMatchesBase = ipAdapter.model?.base === base;
  const hasImage = Boolean(ipAdapter.image);
  return hasModel && modelMatchesBase && hasImage;
};

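addRegions now branches on the main model's base to pick the regional conditioning node type, and fails loudly via assert(false) on anything unrecognized instead of silently falling back to an SDXL-shaped node. The selection logic, reduced to a sketch:

type CondNodeType = 'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder';

const condNodeTypeForBase = (base: string): CondNodeType => {
  if (base === 'sdxl') {
    return 'sdxl_compel_prompt';
  }
  if (base === 'flux') {
    return 'flux_text_encoder';
  }
  return 'compel'; // SD 1.x-style bases use plain compel
};

console.log(condNodeTypeForBase('flux')); // 'flux_text_encoder'
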
@@ -11,6 +11,7 @@ import { addImageToImage } from 'features/nodes/util/graph/generation/addImageTo
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
@@ -79,7 +80,10 @@ export const buildFLUXGraph = async (
    id: getPrefixedId('flux_text_encoder'),
    prompt: positivePrompt,
  });

  const posCondCollect = g.addNode({
    type: 'collect',
    id: getPrefixedId('pos_cond_collect'),
  });
  const denoise = g.addNode({
    type: 'flux_denoise',
    id: getPrefixedId('flux_denoise'),
@@ -104,13 +108,12 @@ export const buildFLUXGraph = async (
  g.addEdge(modelLoader, 'clip', posCond, 'clip');
  g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
  g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
  g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
  g.addEdge(posCondCollect, 'collection', denoise, 'positive_text_conditioning');
  g.addEdge(denoise, 'latents', l2i, 'latents');

  addFLUXLoRAs(state, g, denoise, modelLoader, posCond);

  g.addEdge(posCond, 'conditioning', denoise, 'positive_text_conditioning');

  g.addEdge(denoise, 'latents', l2i, 'latents');

  const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
  assert(modelConfig.base === 'flux');

@@ -196,31 +199,50 @@ export const buildFLUXGraph = async (
    type: 'collect',
    id: getPrefixedId('control_net_collector'),
  });
  const controlNetResult = await addControlNets(
  const controlNetResult = await addControlNets({
    manager,
    canvas.controlLayers.entities,
    entities: canvas.controlLayers.entities,
    g,
    canvas.bbox.rect,
    controlNetCollector,
    modelConfig.base
  );
    rect: canvas.bbox.rect,
    collector: controlNetCollector,
    model: modelConfig,
  });
  if (controlNetResult.addedControlNets > 0) {
    g.addEdge(controlNetCollector, 'collection', denoise, 'control');
  } else {
    g.deleteNode(controlNetCollector.id);
  }

  const ipAdapterCollector = g.addNode({
  const ipAdapterCollect = g.addNode({
    type: 'collect',
    id: getPrefixedId('ip_adapter_collector'),
  });
  const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
  const ipAdapterResult = addIPAdapters({
    entities: canvas.referenceImages.entities,
    g,
    collector: ipAdapterCollect,
    model: modelConfig,
  });

  const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
  const regionsResult = await addRegions({
    manager,
    regions: canvas.regionalGuidance.entities,
    g,
    bbox: canvas.bbox.rect,
    model: modelConfig,
    posCond,
    negCond: null,
    posCondCollect,
    negCondCollect: null,
    ipAdapterCollect,
  });

  const totalIPAdaptersAdded =
    ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
  if (totalIPAdaptersAdded > 0) {
    g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
    g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
  } else {
    g.deleteNode(ipAdapterCollector.id);
    g.deleteNode(ipAdapterCollect.id);
  }

  if (state.system.shouldUseNSFWChecker) {

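buildFLUXGraph follows what might be called a speculative-collector pattern: it adds a collect node up front, lets the helpers wire adapters into it, then either connects its collection to the denoise node or deletes the empty collector. A self-contained sketch with a simplified graph API (the real Graph class has richer edge typing):

type Node = { id: string; type: string };

class MiniGraph {
  nodes = new Map<string, Node>();
  edges: Array<{ from: string; to: string }> = [];
  addNode(node: Node): Node {
    this.nodes.set(node.id, node);
    return node;
  }
  addEdge(from: Node, to: Node): void {
    this.edges.push({ from: from.id, to: to.id });
  }
  deleteNode(id: string): void {
    this.nodes.delete(id);
  }
}

const g = new MiniGraph();
const denoise = g.addNode({ id: 'denoise', type: 'flux_denoise' });
const collector = g.addNode({ id: 'ip_adapter_collector', type: 'collect' });

const addedIPAdapters = 0; // pretend no valid reference images were found
if (addedIPAdapters > 0) {
  g.addEdge(collector, denoise); // collection -> ip_adapter
} else {
  g.deleteNode(collector.id); // avoid a dangling, empty collect node
}
console.log(g.nodes.has('ip_adapter_collector')); // false
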
@@ -227,14 +227,14 @@ export const buildSD1Graph = async (
    type: 'collect',
    id: getPrefixedId('control_net_collector'),
  });
  const controlNetResult = await addControlNets(
  const controlNetResult = await addControlNets({
    manager,
    canvas.controlLayers.entities,
    entities: canvas.controlLayers.entities,
    g,
    canvas.bbox.rect,
    controlNetCollector,
    modelConfig.base
  );
    rect: canvas.bbox.rect,
    collector: controlNetCollector,
    model: modelConfig,
  });
  if (controlNetResult.addedControlNets > 0) {
    g.addEdge(controlNetCollector, 'collection', denoise, 'control');
  } else {
@@ -245,46 +245,50 @@ export const buildSD1Graph = async (
    type: 'collect',
    id: getPrefixedId('t2i_adapter_collector'),
  });
  const t2iAdapterResult = await addT2IAdapters(
  const t2iAdapterResult = await addT2IAdapters({
    manager,
    canvas.controlLayers.entities,
    entities: canvas.controlLayers.entities,
    g,
    canvas.bbox.rect,
    t2iAdapterCollector,
    modelConfig.base
  );
    rect: canvas.bbox.rect,
    collector: t2iAdapterCollector,
    model: modelConfig,
  });
  if (t2iAdapterResult.addedT2IAdapters > 0) {
    g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
  } else {
    g.deleteNode(t2iAdapterCollector.id);
  }

  const ipAdapterCollector = g.addNode({
  const ipAdapterCollect = g.addNode({
    type: 'collect',
    id: getPrefixedId('ip_adapter_collector'),
  });
  const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);

  const regionsResult = await addRegions(
    manager,
    canvas.regionalGuidance.entities,
  const ipAdapterResult = addIPAdapters({
    entities: canvas.referenceImages.entities,
    g,
    canvas.bbox.rect,
    modelConfig.base,
    denoise,
    collector: ipAdapterCollect,
    model: modelConfig,
  });

  const regionsResult = await addRegions({
    manager,
    regions: canvas.regionalGuidance.entities,
    g,
    bbox: canvas.bbox.rect,
    model: modelConfig,
    posCond,
    negCond,
    posCondCollect,
    negCondCollect,
    ipAdapterCollector
  );
    ipAdapterCollect,
  });

  const totalIPAdaptersAdded =
    ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
  if (totalIPAdaptersAdded > 0) {
    g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
    g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
  } else {
    g.deleteNode(ipAdapterCollector.id);
    g.deleteNode(ipAdapterCollect.id);
  }

  if (state.system.shouldUseNSFWChecker) {

@@ -232,14 +232,14 @@ export const buildSDXLGraph = async (
    type: 'collect',
    id: getPrefixedId('control_net_collector'),
  });
  const controlNetResult = await addControlNets(
  const controlNetResult = await addControlNets({
    manager,
    canvas.controlLayers.entities,
    entities: canvas.controlLayers.entities,
    g,
    canvas.bbox.rect,
    controlNetCollector,
    modelConfig.base
  );
    rect: canvas.bbox.rect,
    collector: controlNetCollector,
    model: modelConfig,
  });
  if (controlNetResult.addedControlNets > 0) {
    g.addEdge(controlNetCollector, 'collection', denoise, 'control');
  } else {
@@ -250,46 +250,50 @@ export const buildSDXLGraph = async (
    type: 'collect',
    id: getPrefixedId('t2i_adapter_collector'),
  });
  const t2iAdapterResult = await addT2IAdapters(
  const t2iAdapterResult = await addT2IAdapters({
    manager,
    canvas.controlLayers.entities,
    entities: canvas.controlLayers.entities,
    g,
    canvas.bbox.rect,
    t2iAdapterCollector,
    modelConfig.base
  );
    rect: canvas.bbox.rect,
    collector: t2iAdapterCollector,
    model: modelConfig,
  });
  if (t2iAdapterResult.addedT2IAdapters > 0) {
    g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
  } else {
    g.deleteNode(t2iAdapterCollector.id);
  }

  const ipAdapterCollector = g.addNode({
  const ipAdapterCollect = g.addNode({
    type: 'collect',
    id: getPrefixedId('ip_adapter_collector'),
  });
  const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);

  const regionsResult = await addRegions(
    manager,
    canvas.regionalGuidance.entities,
  const ipAdapterResult = addIPAdapters({
    entities: canvas.referenceImages.entities,
    g,
    canvas.bbox.rect,
    modelConfig.base,
    denoise,
    collector: ipAdapterCollect,
    model: modelConfig,
  });

  const regionsResult = await addRegions({
    manager,
    regions: canvas.regionalGuidance.entities,
    g,
    bbox: canvas.bbox.rect,
    model: modelConfig,
    posCond,
    negCond,
    posCondCollect,
    negCondCollect,
    ipAdapterCollector
  );
    ipAdapterCollect,
  });

  const totalIPAdaptersAdded =
    ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
  if (totalIPAdaptersAdded > 0) {
    g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
    g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
  } else {
    g.deleteNode(ipAdapterCollector.id);
    g.deleteNode(ipAdapterCollect.id);
  }

  if (state.system.shouldUseNSFWChecker) {

@@ -4,13 +4,13 @@ import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleFill } from 'react-icons/pi';

import { useClearQueue } from './ClearQueueConfirmationAlertDialog';
import { useClearQueueDialog } from './ClearQueueConfirmationAlertDialog';

type Props = ButtonProps;

const ClearQueueButton = (props: Props) => {
  const { t } = useTranslation();
  const clearQueue = useClearQueue();
  const clearQueue = useClearQueueDialog();

  return (
    <>

@@ -1,51 +1,15 @@
import { ConfirmationAlertDialog, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useClearQueueMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import { $isConnected } from 'services/events/stores';

const [useClearQueueConfirmationAlertDialog] = buildUseBoolean(false);

export const useClearQueue = () => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
export const useClearQueueDialog = () => {
  const dialog = useClearQueueConfirmationAlertDialog();
  const { data: queueStatus } = useGetQueueStatusQuery();
  const isConnected = useStore($isConnected);
  const [trigger, { isLoading }] = useClearQueueMutation({
    fixedCacheKey: 'clearQueue',
  });

  const clearQueue = useCallback(async () => {
    if (!queueStatus?.queue.total) {
      return;
    }

    try {
      await trigger().unwrap();
      toast({
        id: 'QUEUE_CLEAR_SUCCEEDED',
        title: t('queue.clearSucceeded'),
        status: 'success',
      });
      dispatch(listCursorChanged(undefined));
      dispatch(listPriorityChanged(undefined));
    } catch {
      toast({
        id: 'QUEUE_CLEAR_FAILED',
        title: t('queue.clearFailed'),
        status: 'error',
      });
    }
  }, [queueStatus?.queue.total, trigger, dispatch, t]);

  const isDisabled = useMemo(() => !isConnected || !queueStatus?.queue.total, [isConnected, queueStatus?.queue.total]);
  const { clearQueue, isLoading, isDisabled, queueStatus } = useClearQueue();

  return {
    clearQueue,
@@ -61,7 +25,7 @@ export const useClearQueue = () => {
export const ClearQueueConfirmationsAlertDialog = memo(() => {
  useAssertSingleton('ClearQueueConfirmationsAlertDialog');
  const { t } = useTranslation();
  const clearQueue = useClearQueue();
  const clearQueue = useClearQueueDialog();

  return (
    <ConfirmationAlertDialog

@@ -4,11 +4,11 @@ import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold, PiXBold } from 'react-icons/pi';

import { useClearQueue } from './ClearQueueConfirmationAlertDialog';
import { useClearQueueDialog } from './ClearQueueConfirmationAlertDialog';

export const ClearQueueIconButton = memo((_) => {
  const { t } = useTranslation();
  const clearQueue = useClearQueue();
  const clearQueue = useClearQueueDialog();
  const cancelCurrentQueueItem = useCancelCurrentQueueItem();

  // Show the single item clear button when shift is pressed

@@ -147,8 +147,6 @@ const UpscaleTabTooltipContent = memo(({ prepend = false }: { prepend?: boolean
          <ReasonsList reasons={reasons} />
        </>
      )}
      <StyledDivider />
      <AddingToText />
    </Flex>
  );
});
@@ -180,8 +178,6 @@ const WorkflowsTabTooltipContent = memo(({ prepend = false }: { prepend?: boolea
          <ReasonsList reasons={reasons} />
        </>
      )}
      <StyledDivider />
      <AddingToText />
    </Flex>
  );
});

@@ -1,8 +1,9 @@
import { IconButton, Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SessionMenuItems } from 'common/components/SessionMenuItems';
import { useClearQueue } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { useClearQueueDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { QueueCountBadge } from 'features/queue/components/QueueCountBadge';
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
import { usePauseProcessor } from 'features/queue/hooks/usePauseProcessor';
import { useResumeProcessor } from 'features/queue/hooks/useResumeProcessor';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@@ -17,7 +18,8 @@ export const QueueActionsMenuButton = memo(() => {
  const { t } = useTranslation();
  const isPauseEnabled = useFeatureStatus('pauseQueue');
  const isResumeEnabled = useFeatureStatus('resumeQueue');
  const clearQueue = useClearQueue();
  const cancelCurrent = useCancelCurrentQueueItem();
  const clearQueue = useClearQueueDialog();
  const {
    resumeProcessor,
    isLoading: isLoadingResumeProcessor,
@@ -44,9 +46,9 @@ export const QueueActionsMenuButton = memo(() => {
        <MenuItem
          isDestructive
          icon={<PiXBold />}
          onClick={clearQueue.openDialog}
          isLoading={clearQueue.isLoading}
          isDisabled={clearQueue.isDisabled}
          onClick={cancelCurrent.cancelQueueItem}
          isLoading={cancelCurrent.isLoading}
          isDisabled={cancelCurrent.isDisabled}
        >
          {t('queue.cancelTooltip')}
        </MenuItem>

@@ -0,0 +1,50 @@
import { useStore } from '@nanostores/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useClearQueueMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import { $isConnected } from 'services/events/stores';

export const useClearQueue = () => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const { data: queueStatus } = useGetQueueStatusQuery();
  const isConnected = useStore($isConnected);
  const [trigger, { isLoading }] = useClearQueueMutation({
    fixedCacheKey: 'clearQueue',
  });

  const clearQueue = useCallback(async () => {
    if (!queueStatus?.queue.total) {
      return;
    }

    try {
      await trigger().unwrap();
      toast({
        id: 'QUEUE_CLEAR_SUCCEEDED',
        title: t('queue.clearSucceeded'),
        status: 'success',
      });
      dispatch(listCursorChanged(undefined));
      dispatch(listPriorityChanged(undefined));
    } catch {
      toast({
        id: 'QUEUE_CLEAR_FAILED',
        title: t('queue.clearFailed'),
        status: 'error',
      });
    }
  }, [queueStatus?.queue.total, trigger, dispatch, t]);

  const isDisabled = useMemo(() => !isConnected || !queueStatus?.queue.total, [isConnected, queueStatus?.queue.total]);

  return {
    clearQueue,
    isLoading,
    queueStatus,
    isDisabled,
  };
};
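The mutation/toast logic above now lives in its own hook, so the dialog hook only manages open state while any component can reuse the queue-clearing behavior. A hedged usage sketch — the component and its props are illustrative, and the Button import from @invoke-ai/ui-library is an assumption:

import { Button } from '@invoke-ai/ui-library';
import { useClearQueue } from 'features/queue/hooks/useClearQueue';

// A button wired straight to the extracted hook, bypassing the confirmation dialog.
export const ClearQueueNowButton = () => {
  const { clearQueue, isLoading, isDisabled } = useClearQueue();
  return (
    <Button onClick={clearQueue} isLoading={isLoading} isDisabled={isDisabled}>
      Clear queue
    </Button>
  );
};
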
@@ -4,6 +4,13 @@ import type { ParamsState } from 'features/controlLayers/store/paramsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasState } from 'features/controlLayers/store/types';
import {
  getControlLayerWarnings,
  getGlobalReferenceImageWarnings,
  getInpaintMaskWarnings,
  getRasterLayerWarnings,
  getRegionalGuidanceWarnings,
} from 'features/controlLayers/store/validators';
import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
@@ -278,17 +285,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
      const layerNumber = i + 1;
      const layerType = i18n.t(LAYER_TYPE_TO_TKEY['control_layer']);
      const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
      const problems: string[] = [];
      // Must have model
      if (!controlLayer.controlAdapter.model) {
        problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
      }
      // Model base must match
      if (controlLayer.controlAdapter.model?.base !== model?.base) {
        problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
      }
      const problems = getControlLayerWarnings(controlLayer, model);

      if (problems.length) {
        const content = upperFirst(problems.join(', '));
        const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
        reasons.push({ prefix, content });
      }
    });
@@ -300,23 +300,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
      const layerNumber = i + 1;
      const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
      const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
      const problems: string[] = [];

      // Must have model
      if (!entity.ipAdapter.model) {
        problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
      }
      // Model base must match
      if (entity.ipAdapter.model?.base !== model?.base) {
        problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
      }
      // Must have an image
      if (!entity.ipAdapter.image) {
        problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
      }
      const problems = getGlobalReferenceImageWarnings(entity, model);

      if (problems.length) {
        const content = upperFirst(problems.join(', '));
        const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
        reasons.push({ prefix, content });
      }
    });
@@ -328,32 +315,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
      const layerNumber = i + 1;
      const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
      const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
      const problems: string[] = [];
      // Must have a region
      if (entity.objects.length === 0) {
        problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
      }
      // Must have at least 1 prompt or IP Adapter
      if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) {
        problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
      }
      entity.referenceImages.forEach(({ ipAdapter }) => {
        // Must have model
        if (!ipAdapter.model) {
          problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
        }
        // Model base must match
        if (ipAdapter.model?.base !== model?.base) {
          problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
        }
        // Must have an image
        if (!ipAdapter.image) {
          problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
        }
      });
      const problems = getRegionalGuidanceWarnings(entity, model);

      if (problems.length) {
        const content = upperFirst(problems.join(', '));
        const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
        reasons.push({ prefix, content });
      }
    });
@@ -365,10 +330,25 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
      const layerNumber = i + 1;
      const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
      const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
      const problems: string[] = [];
      const problems = getRasterLayerWarnings(entity, model);

      if (problems.length) {
        const content = upperFirst(problems.join(', '));
        const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
        reasons.push({ prefix, content });
      }
    });

  canvas.inpaintMasks.entities
    .filter((entity) => entity.isEnabled)
    .forEach((entity, i) => {
      const layerLiteral = i18n.t('controlLayers.layer_one');
      const layerNumber = i + 1;
      const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
      const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
      const problems = getInpaintMaskWarnings(entity, model);

      if (problems.length) {
        const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
        reasons.push({ prefix, content });
      }
    });

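Note the recurring change from problems.join(', ') to problems.map((p) => i18n.t(p)).join(', '): the validators now return i18n keys, and translation happens once at display time. The convention, reduced to a runnable sketch with a stand-in catalog (the real app uses i18next):

const problems = [
  'parameters.invoke.layer.ipAdapterNoModelSelected',
  'parameters.invoke.layer.ipAdapterNoImageSelected',
];

// Stand-in for i18n.t: look keys up in a catalog, falling back to the key itself.
const catalog: Record<string, string> = {
  'parameters.invoke.layer.ipAdapterNoModelSelected': 'no model selected',
  'parameters.invoke.layer.ipAdapterNoImageSelected': 'no image selected',
};
const t = (key: string): string => catalog[key] ?? key;

const upperFirst = (s: string): string => s.charAt(0).toUpperCase() + s.slice(1);
const content = upperFirst(problems.map((p) => t(p)).join(', '));
console.log(content); // 'No model selected, no image selected'
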
@@ -31,6 +31,7 @@ const optionsObject: Record<Language, string> = {
  sv: 'Svenska',
  tr: 'Türkçe',
  ua: 'Украї́нська',
  vi: 'tiếng Việt',
  zh_CN: '简体中文',
  zh_Hant: '漢語',
};

@@ -22,6 +22,7 @@ const zLanguage = z.enum([
  'sv',
  'tr',
  'ua',
  'vi',
  'zh_CN',
  'zh_Hant',
]);

@@ -3,7 +3,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { ToolChooser } from 'features/controlLayers/components/Tool/ToolChooser';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useClearQueue } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { useClearQueueDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { InvokeButtonTooltip } from 'features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip';
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
import { useInvoke } from 'features/queue/hooks/useInvoke';
@@ -31,7 +31,7 @@ const FloatingSidePanelButtons = (props: Props) => {
  const shift = useShiftModifier();
  const tab = useAppSelector(selectActiveTab);
  const imageViewer = useImageViewer();
  const clearQueue = useClearQueue();
  const clearQueue = useClearQueueDialog();
  const { data: queueStatus } = useGetQueueStatusQuery();
  const cancelCurrent = useCancelCurrentQueueItem();

@@ -1,4 +1,5 @@
import type { StartQueryActionCreatorOptions } from '@reduxjs/toolkit/dist/query/core/buildInitiate';
import { $authToken } from 'app/store/nanostores/authToken';
import { getStore } from 'app/store/nanostores/store';
import type { BoardId } from 'features/gallery/store/types';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
@@ -624,3 +625,20 @@ export const uploadImages = async (args: UploadImageArg[]): Promise<ImageDTO[]>
  );
  return results.filter((r): r is PromiseFulfilledResult<ImageDTO> => r.status === 'fulfilled').map((r) => r.value);
};

/**
 * Convert an ImageDTO to a File by downloading the image from the server.
 * @param imageDTO The image to download and convert to a File
 */
export const imageDTOToFile = async (imageDTO: ImageDTO): Promise<File> => {
  const init: RequestInit = {};
  const authToken = $authToken.get();
  if (authToken) {
    init.headers = { Authorization: `Bearer ${authToken}` };
  }
  const res = await fetch(imageDTO.image_url, init);
  const blob = await res.blob();
  // Create a new file with the same name, which we will upload
  const file = new File([blob], `copy_of_${imageDTO.image_name}`, { type: 'image/png' });
  return file;
};

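One detail worth noting in imageDTOToFile above: it hard-codes type: 'image/png' even though the downloaded blob carries its own MIME type. A hedged variant that preserves the server-reported type, assuming image_url may also serve JPEG or WebP:

const imageDTOToFilePreservingType = async (imageDTO: { image_url: string; image_name: string }): Promise<File> => {
  const res = await fetch(imageDTO.image_url);
  const blob = await res.blob();
  // Fall back to PNG only when the server did not report a content type.
  const type = blob.type || 'image/png';
  return new File([blob], `copy_of_${imageDTO.image_name}`, { type });
};
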
File diff suppressed because one or more lines are too long
@@ -56,7 +56,7 @@ dependencies = [
  "torchmetrics",
  "torchsde",
  "torchvision",
  "transformers==4.41.1",
  "transformers==4.46.3",

  # Core application dependencies, pinned for reproducible builds.
  "fastapi-events==0.11.1",

tests/backend/model_cache_v2/test_torch_autocast_context.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import torch

from invokeai.backend.model_cache_v2.torch_autocast_context import TorchAutocastContext


class DummyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def test_torch_autocast_context():
    model = DummyModule()

    with TorchAutocastContext(to_device=torch.device("cuda")):
        x = torch.randn(10, 10, device="cuda")
        y = model(x)
        print(y.shape)

@@ -0,0 +1,13 @@
import torch


class DummyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = self.linear2(x)
        return x

@@ -0,0 +1,50 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
    CachedModelOnlyFullLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.dummy_module import DummyModule

parameterize_mps_and_cuda = pytest.mark.parametrize(
    ("device"),
    [
        pytest.param(
            "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
        ),
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
    ],
)


@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str):
    model = DummyModule()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
    assert cached_model.total_bytes() == 100


@parameterize_mps_and_cuda
def test_cached_model_is_in_vram(device: str):
    model = DummyModule()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
    assert not cached_model.is_in_vram()

    cached_model.full_load_to_vram()
    assert cached_model.is_in_vram()

    cached_model.full_unload_from_vram()
    assert not cached_model.is_in_vram()


@parameterize_mps_and_cuda
def test_cached_model_full_load_and_unload(device: str):
    model = DummyModule()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
    assert cached_model.full_load_to_vram() == 100
    assert cached_model.is_in_vram()
    assert all(p.device.type == device for p in cached_model.model.parameters())

    assert cached_model.full_unload_from_vram() == 100
    assert not cached_model.is_in_vram()
    assert all(p.device.type == "cpu" for p in cached_model.model.parameters())

@@ -0,0 +1,72 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
    CachedModelWithPartialLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.dummy_module import DummyModule

parameterize_mps_and_cuda = pytest.mark.parametrize(
    ("device"),
    [
        pytest.param(
            "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
        ),
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
    ],
)


@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str):
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available.")
    if device == "mps" and not torch.backends.mps.is_available():
        pytest.skip("MPS is not available.")

    model = DummyModule()
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
    linear_numel = 10 * 10 + 10
    assert cached_model.total_bytes() == linear_numel * 4 * 2

    cached_model.model.to(dtype=torch.float16)
    assert cached_model.total_bytes() == linear_numel * 2 * 2


@parameterize_mps_and_cuda
def test_cached_model_cur_vram_bytes(device: str):
    model = DummyModule()
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
    assert cached_model.cur_vram_bytes() == 0

    cached_model.model.to(device=torch.device(device))
    assert cached_model.cur_vram_bytes() == cached_model.total_bytes()


@parameterize_mps_and_cuda
def test_cached_model_partial_load(device: str):
    model = DummyModule()
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
    model_total_bytes = cached_model.total_bytes()
    assert cached_model.cur_vram_bytes() == 0

    target_vram_bytes = int(model_total_bytes * 0.6)
    loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
    assert loaded_bytes > 0
    assert loaded_bytes < model_total_bytes
    assert loaded_bytes == cached_model.cur_vram_bytes()


@parameterize_mps_and_cuda
def test_cached_model_partial_unload(device: str):
    model = DummyModule()
    model.to(device=torch.device(device))
    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
    model_total_bytes = cached_model.total_bytes()
    assert cached_model.cur_vram_bytes() == model_total_bytes

    bytes_to_free = int(model_total_bytes * 0.4)
    freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
    assert freed_bytes >= bytes_to_free
    assert freed_bytes < model_total_bytes
    assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()

@@ -25,7 +25,7 @@ from invokeai.backend.model_manager.config import (
    ModelVariantType,
    VAEDiffusersConfig,
)
from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_metadata.metadata_examples import (
    HFTestLoraMetadata,