Compare commits

...

33 Commits

Author SHA1 Message Date
psychedelicious
ac1071a5e5 chore: v4.1.0 2024-04-18 07:19:22 +10:00
Riccardo Giovanetti
5295a398f3 translationBot(ui): update translation (Italian)
Currently translated at 98.4% (1122 of 1140 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-04-17 08:41:57 +10:00
Anonymous
0c7283c82d translationBot(ui): update translation (Turkish)
Currently translated at 50.8% (580 of 1140 strings)

translationBot(ui): update translation (Korean)

Currently translated at 43.3% (494 of 1140 strings)

translationBot(ui): update translation (Chinese (Simplified))

Currently translated at 80.9% (923 of 1140 strings)

translationBot(ui): update translation (Russian)

Currently translated at 98.8% (1127 of 1140 strings)

translationBot(ui): update translation (Dutch)

Currently translated at 63.7% (727 of 1140 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 50.4% (575 of 1140 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.3% (1121 of 1140 strings)

translationBot(ui): update translation (Spanish)

Currently translated at 27.8% (317 of 1140 strings)

translationBot(ui): update translation (German)

Currently translated at 72.2% (824 of 1140 strings)

Co-authored-by: Anonymous <noreply@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/es/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ko/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/nl/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/tr/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/zh_Hans/
Translation: InvokeAI/Web UI
2024-04-17 08:41:57 +10:00
Mary Hipp
73ad173c74 update labels for Style Only and CompositionOnly to be designated as beta 2024-04-17 08:29:10 +10:00
Kent Keirsey
c828a4e59f Add IP Adapter Style & Composition Modes (#6213)
## Summary

Until now IP Adapter had complete control on the contents of the output.
With this PR, users are now able to select "Style Only" or "Composition
Only" to draw just the style or layout of the reference image.

Based off: https://arxiv.org/abs/2404.02733

### New IP Method Option

- `Full` - Both style and layout of the refence image are used.
- `Style Only` - Only the style of the image is used
- `Composition Only` - Only the composition of the image is used.


![opera_0BkqZTwObO](https://github.com/invoke-ai/InvokeAI/assets/54517381/1b2fbbba-44c9-4c25-87cb-3711a17d13e3)

### Example Result


![demo](https://github.com/invoke-ai/InvokeAI/assets/54517381/703f3de5-e685-4691-acda-9338a4c10796)

### Notes

- Supports both SDXL and SD1.5

### Testing

- Just check and test if it works as expected with all IP Adapter models
- both SDXL and SD1.5

## Merge Plan

Good to merge once tested for all edge cases.
2024-04-16 14:23:36 -04:00
blessedcoolant
6bab040d24 Merge branch 'main' into ip-adapter-style-comp 2024-04-16 21:14:06 +05:30
blessedcoolant
f46bbaf8c4 fix: make ip-adapter weights not be optional 2024-04-16 21:12:45 +05:30
Lincoln Stein
fce6b3e44c maybe solve race issue 2024-04-16 13:09:26 +10:00
blessedcoolant
d27907cc6d fix: entire reshaping block needs to be skipped 2024-04-16 04:29:53 +05:30
blessedcoolant
7ee3fef2db cleanup: better var names for the ip adapter weight collection block 2024-04-16 04:23:50 +05:30
blessedcoolant
b39ce642b6 cleanup: raise ValueErrors when target_blocks dont match base model 2024-04-16 04:12:30 +05:30
blessedcoolant
a148c4322c fix: IP Adapter weights being incorrectly applied
They were being overwritten rather than being appended
2024-04-16 04:10:41 +05:30
blessedcoolant
f6b7bc5d98 fix: Dynamically adapt height of control adapter opts 2024-04-16 01:18:43 +05:30
blessedcoolant
5f6c6abf9c chore: change IPAdapterAttentionWeights to a dataclass 2024-04-15 23:38:55 +05:30
blessedcoolant
cd76a31a8f fix: IP Adapter method not being recalled 2024-04-15 22:29:32 +05:30
Lincoln Stein
e93f4d632d [util] Add generic torch device class (#6174)
* introduce new abstraction layer for GPU devices

* add unit test for device abstraction

* fix ruff

* convert TorchDeviceSelect into a stateless class

* move logic to select context-specific execution device into context API

* add mock hardware environments to pytest

* remove dangling mocker fixture

* fix unit test for running on non-CUDA systems

* remove unimplemented get_execution_device() call

* remove autocast precision

* Multiple changes:

1. Remove TorchDeviceSelect.get_execution_device(), as well as calls to
   context.models.get_execution_device().
2. Rename TorchDeviceSelect to TorchDevice
3. Added back the legacy public API defined in `invocation_api`, including
   choose_precision().
4. Added a config file migration script to accommodate removal of precision=autocast.

* add deprecation warnings to choose_torch_device() and choose_precision()

* fix test crash

* remove app_config argument from choose_torch_device() and choose_torch_dtype()

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-04-15 13:12:49 +00:00
psychedelicious
5a8489bbfc perf(ui): memoize infill components 2024-04-15 22:50:54 +10:00
psychedelicious
a24c9d0f7a perf(ui): optimize useFeatureStatus 2024-04-15 22:50:54 +10:00
psychedelicious
7a92afc117 perf(ui): fix rerenders in nodes
Unmemoized selector tanking perf
2024-04-15 22:50:54 +10:00
psychedelicious
b508945b11 feat(ui): edge labels
Add setting to render labels with format `Source Node label -> Target Node label` on edges.
2024-04-15 22:48:46 +10:00
blessedcoolant
8426f1e7b2 fix(experimental): Possible fix for conflict with regional embed length mismatch
Pushing this so people can test it out and see if this needs to be handled in a different way.
2024-04-14 12:19:19 +05:30
blessedcoolant
9cb0f63c44 refactor: fix a bunch of type issues in custom_attention 2024-04-13 14:17:25 +05:30
blessedcoolant
2d5786d3bb fix: Incorrect composition blocks for SD1.5 2024-04-13 13:52:10 +05:30
blessedcoolant
27466ffa1a chore: update the ip adapter node version 2024-04-13 13:39:08 +05:30
blessedcoolant
f50b156511 chore: do not include custom nodes in schema 2024-04-13 12:43:49 +05:30
blessedcoolant
9fc73743b2 feat: support SD1.5 2024-04-13 12:30:39 +05:30
blessedcoolant
d4393e4170 chore: linter fixes 2024-04-13 12:14:45 +05:30
blessedcoolant
145a0b029e Merge branch 'ip-adapter-style-comp' of https://github.com/blessedcoolant/InvokeAI into ip-adapter-style-comp 2024-04-13 12:13:06 +05:30
blessedcoolant
f2506cc769 chore: ruff fixes
Revert "chore: ruff fixes"

This reverts commit af36fe8c1e.

Revert "chore: ruff fixes"

This reverts commit af36fe8c1e.
2024-04-13 12:12:33 +05:30
blessedcoolant
7a67fd6a06 Revert "chore: ruff fixes"
This reverts commit af36fe8c1e.
2024-04-13 12:10:20 +05:30
blessedcoolant
af36fe8c1e chore: ruff fixes 2024-04-13 12:08:52 +05:30
blessedcoolant
e9f16ac8c7 feat: add UI for IP Adapter Method 2024-04-13 12:06:59 +05:30
blessedcoolant
6ea183f0d4 wip: Initial Implementation IP Adapter Style & Comp Modes 2024-04-13 11:09:45 +05:30
73 changed files with 875 additions and 358 deletions

View File

@@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import get_torch_device_name
from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config)
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
torch_device_name = get_torch_device_name()
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")

View File

@@ -24,7 +24,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningFieldData,
SDXLConditioningInfo,
)
from invokeai.backend.util.devices import torch_dtype
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import CLIPField
@@ -99,7 +99,7 @@ class CompelInvocation(BaseInvocation):
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
)
@@ -193,7 +193,7 @@ class SDXLPromptInvocationBase:
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,

View File

@@ -4,20 +4,8 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
OutputField,
TensorField,
UIType,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@@ -36,6 +24,7 @@ class IPAdapterField(BaseModel):
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
@@ -69,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.3.0")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
@@ -90,6 +79,9 @@ class IPAdapterInvocation(BaseInvocation):
weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
method: Literal["full", "style", "composition"] = InputField(
default="full", description="The method to apply the IP-Adapter"
)
begin_step_percent: float = InputField(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
@@ -124,12 +116,32 @@ class IPAdapterInvocation(BaseInvocation):
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
if self.method == "style":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.1"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["up_blocks.0.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "composition":
if ip_adapter_info.base == "sd-1":
target_blocks = ["down_blocks.2", "mid_block"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "full":
target_blocks = ["block"]
else:
raise ValueError(f"Unexpected IP-Adapter method: '{self.method}'.")
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
weight=self.weight,
target_blocks=target_blocks,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
mask=self.mask,

View File

@@ -72,15 +72,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from ...backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField
if choose_torch_device() == torch.device("mps"):
from torch import mps
DEFAULT_PRECISION = choose_precision(choose_torch_device())
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
@invocation_output("scheduler_output")
@@ -682,6 +679,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight,
target_blocks=single_ip_adapter.target_blocks,
begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
@@ -959,9 +957,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
@@ -1028,9 +1024,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.disable_tiling()
# clear memory as vae decode can request a lot
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
with torch.inference_mode():
# copied from diffusers pipeline
@@ -1042,9 +1036,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
image_dto = context.images.save(image=image)
@@ -1083,9 +1075,7 @@ class ResizeLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
@@ -1096,9 +1086,8 @@ class ResizeLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1125,8 +1114,7 @@ class ScaleLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
@@ -1138,9 +1126,7 @@ class ScaleLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1272,8 +1258,7 @@ class BlendLatentsInvocation(BaseInvocation):
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
@@ -1326,9 +1311,8 @@ class BlendLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents)

View File

@@ -36,6 +36,7 @@ class IPAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")

View File

@@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.util.devices import TorchDevice
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@@ -46,7 +46,7 @@ def get_noise(
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
dtype=TorchDevice.choose_torch_dtype(device=device),
device=noise_device_type,
generator=generator,
).to("cpu")
@@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):
@field_validator("seed", mode="before")
def modulo_seed(cls, v):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)
def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise(
width=self.width,
height=self.height,
device=choose_torch_device(),
device=TorchDevice.choose_torch_device(),
seed=self.seed,
use_cpu=self.use_cpu,
)

View File

@@ -4,7 +4,6 @@ from typing import Literal
import cv2
import numpy as np
import torch
from PIL import Image
from pydantic import ConfigDict
@@ -14,7 +13,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
@@ -35,9 +34,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
}
if choose_torch_device() == torch.device("mps"):
from torch import mps
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -120,9 +116,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
image_dto = context.images.save(image=pil_image)

View File

@@ -27,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.0"
CONFIG_SCHEMA_VERSION = "4.0.1"
def get_default_ram_cache_size() -> float:
@@ -105,7 +105,7 @@ class InvokeAIAppConfig(BaseSettings):
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
@@ -370,6 +370,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
@@ -392,6 +395,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
return config
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@@ -418,17 +443,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
else:
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
@lru_cache(maxsize=1)

View File

@@ -13,6 +13,7 @@ from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union
import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
@@ -42,7 +43,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
@@ -634,11 +635,10 @@ class ModelInstallService(ModelInstallServiceBase):
self._next_job_id += 1
return id
@staticmethod
def _guess_variant() -> Optional[ModelRepoVariant]:
def _guess_variant(self) -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download."""
precision = choose_precision(choose_torch_device())
return ModelRepoVariant.FP16 if precision == "float16" else None
precision = TorchDevice.choose_torch_dtype()
return ModelRepoVariant.FP16 if precision == torch.float16 else None
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
return ModelInstallJob(
@@ -754,6 +754,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache[download_job.source] = install_job # matches a download job to an install job
install_job.download_parts.add(download_job)
# only start the jobs once install_job.download_parts is fully populated
for download_job in install_job.download_parts:
self._download_queue.submit_download_job(
download_job,
on_start=self._download_started_callback,
@@ -762,6 +764,7 @@ class ModelInstallService(ModelInstallServiceBase):
on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback,
)
return install_job
def _stat_size(self, path: Path) -> int:

View File

@@ -1,12 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
@@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: torch.device = choose_torch_device(),
execution_device: Optional[torch.device] = None,
) -> Self:
"""
Construct the model manager service instance.
@@ -82,7 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
execution_device=execution_device,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(

View File

@@ -13,7 +13,7 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
@@ -56,7 +56,7 @@ class DepthAnythingDetector:
def __init__(self) -> None:
self.model = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
self.device = choose_torch_device()
self.device = TorchDevice.choose_torch_device()
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
@@ -81,7 +81,7 @@ class DepthAnythingDetector:
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
self.model.eval()
self.model.to(choose_torch_device())
self.model.to(self.device)
return self.model
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
@@ -94,7 +94,7 @@ class DepthAnythingDetector:
image_height, image_width = np_image.shape[:2]
np_image = transform({"image": np_image})["image"]
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
with torch.no_grad():
depth = self.model(tensor_image)

View File

@@ -7,7 +7,7 @@ import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from .onnxdet import inference_detector
from .onnxpose import inference_pose
@@ -28,9 +28,9 @@ config = get_config()
class Wholebody:
def __init__(self):
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)

View File

@@ -8,7 +8,7 @@ from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
def norm_img(np_img):
@@ -29,7 +29,7 @@ def load_jit_model(url_or_path, device):
class LaMA:
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
model_location = get_config().models_path / "core/misc/lama/lama.pt"
if not model_location.exists():

View File

@@ -11,7 +11,7 @@ from cv2.typing import MatLike
from tqdm import tqdm
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
"""
Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
@@ -65,7 +65,7 @@ class RealESRGAN:
self.pre_pad = pre_pad
self.mod_scale: Optional[int] = None
self.half = half
self.device = choose_torch_device()
self.device = TorchDevice.choose_torch_device()
loadnet = torch.load(model_path, map_location=torch.device("cpu"))

View File

@@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
@@ -51,7 +51,7 @@ class SafetyChecker:
cls._load_safety_checker()
if cls.safety_checker is None or cls.feature_extractor is None:
return False
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
features = cls.feature_extractor([image], return_tensors="pt")
features.to(device)
cls.safety_checker.to(device)

View File

@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from invokeai.backend.util.devices import TorchDevice
# TO DO: The loader is not thread safe!
@@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._torch_dtype = torch_dtype(choose_torch_device())
self._torch_dtype = TorchDevice.choose_torch_dtype()
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""

View File

@@ -30,15 +30,12 @@ import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
from .model_locker import ModelLocker
if choose_torch_device() == torch.device("mps"):
from torch import mps
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
@@ -244,9 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.
@@ -416,10 +411,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.stats.cleared = models_cleared
gc.collect()
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:

View File

@@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from invokeai.backend.util.devices import TorchDevice
from . import (
AnyModelConfig,
@@ -43,6 +43,7 @@ class ModelMerger(object):
Initialize a ModelMerger object with the model installer.
"""
self._installer = installer
self._dtype = TorchDevice.choose_torch_dtype()
def merge_diffusion_models(
self,
@@ -68,7 +69,7 @@ class ModelMerger(object):
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
dtype = torch.float16 if variant == "fp16" else self._dtype
# Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
# until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
@@ -151,7 +152,7 @@ class ModelMerger(object):
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
dtype = torch.float16 if variant == "fp16" else self._dtype
merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
# register model and get its unique key

View File

@@ -21,14 +21,11 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
IPAdapterData,
TextConditioningData,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import normalize_device
from invokeai.backend.util.devices import TorchDevice
@dataclass
@@ -258,7 +255,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.unet.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
else:
raise ValueError(f"unrecognized device {self.unet.device}")
# input tensor of [1, 4, h/8, w/8]
@@ -394,8 +391,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext()
if use_ip_adapter or use_regional_prompting:
ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
ip_adapters: Optional[List[UNetIPAdapterData]] = (
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
if use_ip_adapter
else None
)
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

View File

@@ -53,6 +53,7 @@ class IPAdapterData:
ip_adapter_model: IPAdapter
ip_adapter_conditioning: IPAdapterConditioningInfo
mask: torch.Tensor
target_blocks: List[str]
# Either a single weight applied to all steps, or a list of weights for each step.
weight: Union[float, List[float]] = 1.0

View File

@@ -1,4 +1,5 @@
from typing import Optional
from dataclasses import dataclass
from typing import List, Optional, cast
import torch
import torch.nn.functional as F
@@ -9,6 +10,12 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@dataclass
class IPAdapterAttentionWeights:
ip_adapter_weights: IPAttentionProcessorWeights
skip: bool
class CustomAttnProcessor2_0(AttnProcessor2_0):
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
This implementation is based on
@@ -20,7 +27,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
def __init__(
self,
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
):
"""Initialize a CustomAttnProcessor2_0.
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@@ -30,23 +37,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
for the i'th IP-Adapter.
"""
super().__init__()
self._ip_adapter_weights = ip_adapter_weights
def _is_ip_adapter_enabled(self) -> bool:
return self._ip_adapter_weights is not None
self._ip_adapter_attention_weights = ip_adapter_attention_weights
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
# For regional prompting:
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
# For Regional Prompting:
regional_prompt_data: Optional[RegionalPromptData] = None,
percent_through: Optional[torch.FloatTensor] = None,
percent_through: Optional[torch.Tensor] = None,
# For IP-Adapter:
regional_ip_data: Optional[RegionalIPData] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
"""Apply attention.
Args:
@@ -130,17 +136,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Apply IP-Adapter conditioning.
if is_cross_attention:
if self._is_ip_adapter_enabled():
if self._ip_adapter_attention_weights:
assert regional_ip_data is not None
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
assert (
len(regional_ip_data.image_prompt_embeds)
== len(self._ip_adapter_weights)
== len(self._ip_adapter_attention_weights)
== len(regional_ip_data.scales)
== ip_masks.shape[1]
)
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
ipa_weights = self._ip_adapter_weights[ipa_index]
ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
ipa_scale = regional_ip_data.scales[ipa_index]
ip_mask = ip_masks[0, ipa_index, ...]
@@ -153,29 +161,33 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
if not self._ip_adapter_attention_weights[ipa_index].skip:
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
# Expected ip_key and ip_value shape:
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# Expected ip_key and ip_value shape:
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
else:
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
assert regional_ip_data is None
@@ -188,11 +200,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End of unmodified block from AttnProcessor2_0
return hidden_states
# casting torch.Tensor to torch.FloatTensor to avoid type issues
return cast(torch.FloatTensor, hidden_states)

View File

@@ -1,17 +1,25 @@
from contextlib import contextmanager
from typing import Optional
from typing import List, Optional, TypedDict
from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
CustomAttnProcessor2_0,
IPAdapterAttentionWeights,
)
class UNetIPAdapterData(TypedDict):
ip_adapter: IPAdapter
target_blocks: List[str]
class UNetAttentionPatcher:
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
self._ip_adapters = ip_adapters
def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
self._ip_adapters = ip_adapter_data
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
@@ -26,9 +34,22 @@ class UNetAttentionPatcher:
attn_procs[name] = CustomAttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = CustomAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
)
ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
for ip_adapter in self._ip_adapters:
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
skip = True
for block in ip_adapter["target_blocks"]:
if block in name:
skip = False
break
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
ip_adapter_weights=ip_adapter_weights, skip=skip
)
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
return attn_procs
@contextmanager

View File

@@ -2,7 +2,6 @@
Initialization file for invokeai.backend.util
"""
from .devices import choose_precision, choose_torch_device
from .logging import InvokeAILogger
from .util import GIG, Chdir, directory_size
@@ -11,6 +10,4 @@ __all__ = [
"directory_size",
"Chdir",
"InvokeAILogger",
"choose_precision",
"choose_torch_device",
]

View File

@@ -1,89 +1,110 @@
from __future__ import annotations
from contextlib import nullcontext
from typing import Literal, Optional, Union
from typing import Dict, Literal, Optional, Union
import torch
from torch import autocast
from deprecated import deprecated
from invokeai.app.services.config.config_default import PRECISION, get_config
from invokeai.app.services.config.config_default import get_config
# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
def choose_precision(device: torch.device) -> TorchPrecisionNames:
"""Return the string representation of the recommended torch device."""
torch_dtype = TorchDevice.choose_torch_dtype(device)
return PRECISION_TO_NAME[torch_dtype]
@deprecated("Use TorchDevice.choose_torch_device() instead.") # type: ignore
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
config = get_config()
if config.device == "auto":
if torch.cuda.is_available():
return torch.device("cuda")
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
"""Return the torch.device to use for accelerated inference."""
return TorchDevice.choose_torch_device()
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
def torch_dtype(device: torch.device) -> torch.dtype:
"""Return the torch precision for the recommended torch device."""
return TorchDevice.choose_torch_dtype(device)
NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}
class TorchDevice:
"""Abstraction layer for torch devices."""
@classmethod
def choose_torch_device(cls) -> torch.device:
"""Return the torch.device to use for accelerated inference."""
app_config = get_config()
if app_config.device != "auto":
device = torch.device(app_config.device)
elif torch.cuda.is_available():
device = CUDA_DEVICE
elif torch.backends.mps.is_available():
device = MPS_DEVICE
else:
return CPU_DEVICE
else:
return torch.device(config.device)
device = CPU_DEVICE
return cls.normalize(device)
@classmethod
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
"""Return the precision to use for accelerated inference."""
device = device or cls.choose_torch_device()
config = get_config()
if device.type == "cuda" and torch.cuda.is_available():
device_name = torch.cuda.get_device_name(device)
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
# These GPUs have limited support for float16
return cls._to_dtype("float32")
elif config.precision == "auto":
# Default to float16 for CUDA devices
return cls._to_dtype("float16")
else:
# Use the user-defined precision
return cls._to_dtype(config.precision)
def get_torch_device_name() -> str:
device = choose_torch_device()
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
elif device.type == "mps" and torch.backends.mps.is_available():
if config.precision == "auto":
# Default to float16 for MPS devices
return cls._to_dtype("float16")
else:
# Use the user-defined precision
return cls._to_dtype(config.precision)
# CPU / safe fallback
return cls._to_dtype("float32")
@classmethod
def get_torch_device_name(cls) -> str:
"""Return the device name for the current torch device."""
device = cls.choose_torch_device()
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
"""Return an appropriate precision for the given torch device."""
app_config = get_config()
if device.type == "cuda":
device_name = torch.cuda.get_device_name(device)
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
# These GPUs have limited support for float16
return "float32"
elif app_config.precision == "auto" or app_config.precision == "autocast":
# Default to float16 for CUDA devices
return "float16"
else:
# Use the user-defined precision
return app_config.precision
elif device.type == "mps":
if app_config.precision == "auto" or app_config.precision == "autocast":
# Default to float16 for MPS devices
return "float16"
else:
# Use the user-defined precision
return app_config.precision
# CPU / safe fallback
return "float32"
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
device = device or choose_torch_device()
precision = choose_precision(device)
if precision == "float16":
return torch.float16
if precision == "bfloat16":
return torch.bfloat16
else:
# "auto", "autocast", "float32"
return torch.float32
def choose_autocast(precision: PRECISION):
"""Returns an autocast context or nullcontext for the given precision string"""
# float16 currently requires autocast to avoid errors like:
# 'expected scalar type Half but found Float'
if precision == "autocast" or precision == "float16":
return autocast
return nullcontext
def normalize_device(device: Union[str, torch.device]) -> torch.device:
"""Ensure device has a device index defined, if appropriate."""
device = torch.device(device)
if device.index is None:
# cuda might be the only torch backend that currently uses the device index?
# I don't see anything like `current_device` for cpu or mps.
if device.type == "cuda":
@classmethod
def normalize(cls, device: Union[str, torch.device]) -> torch.device:
"""Add the device index to CUDA devices."""
device = torch.device(device)
if device.index is None and device.type == "cuda" and torch.cuda.is_available():
device = torch.device(device.type, torch.cuda.current_device())
return device
return device
@classmethod
def empty_cache(cls) -> None:
"""Clear the GPU device cache."""
if torch.backends.mps.is_available():
torch.mps.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@classmethod
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
return NAME_TO_PRECISION[precision_name]

View File

@@ -85,7 +85,8 @@
"loadMore": "Mehr laden",
"noImagesInGallery": "Keine Bilder in der Galerie",
"loading": "Lade",
"deleteImage": "Lösche Bild",
"deleteImage_one": "Lösche Bild",
"deleteImage_other": "",
"copy": "Kopieren",
"download": "Runterladen",
"setCurrentImage": "Setze aktuelle Bild",

View File

@@ -69,6 +69,7 @@
"auto": "Auto",
"back": "Back",
"batch": "Batch Manager",
"beta": "Beta",
"cancel": "Cancel",
"copy": "Copy",
"copyError": "$t(gallery.copy) Error",
@@ -213,6 +214,10 @@
"resize": "Resize",
"resizeSimple": "Resize (Simple)",
"resizeMode": "Resize Mode",
"ipAdapterMethod": "Method",
"full": "Full",
"style": "Style Only",
"composition": "Composition Only",
"safe": "Safe",
"saveControlImage": "Save Control Image",
"scribble": "scribble",
@@ -770,6 +775,8 @@
"float": "Float",
"fullyContainNodes": "Fully Contain Nodes to Select",
"fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
"showEdgeLabels": "Show Edge Labels",
"showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes",
"hideLegendNodes": "Hide Field Type Legend",
"hideMinimapnodes": "Hide MiniMap",
"inputMayOnlyHaveOneConnection": "Input may only have one connection",

View File

@@ -33,7 +33,9 @@
"autoSwitchNewImages": "Auto seleccionar Imágenes nuevas",
"loadMore": "Cargar más",
"noImagesInGallery": "No hay imágenes para mostrar",
"deleteImage": "Eliminar Imagen",
"deleteImage_one": "Eliminar Imagen",
"deleteImage_many": "",
"deleteImage_other": "",
"deleteImageBin": "Las imágenes eliminadas se enviarán a la papelera de tu sistema operativo.",
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
"assets": "Activos",

View File

@@ -82,7 +82,9 @@
"autoSwitchNewImages": "Passaggio automatico a nuove immagini",
"loadMore": "Carica altro",
"noImagesInGallery": "Nessuna immagine da visualizzare",
"deleteImage": "Elimina l'immagine",
"deleteImage_one": "Elimina l'immagine",
"deleteImage_many": "Elimina {{count}} immagini",
"deleteImage_other": "Elimina {{count}} immagini",
"deleteImagePermanent": "Le immagini eliminate non possono essere ripristinate.",
"deleteImageBin": "Le immagini eliminate verranno spostate nel cestino del tuo sistema operativo.",
"assets": "Risorse",

View File

@@ -90,7 +90,7 @@
"problemDeletingImages": "画像の削除中に問題が発生",
"drop": "ドロップ",
"dropOrUpload": "$t(gallery.drop) またはアップロード",
"deleteImage": "画像を削除",
"deleteImage_other": "画像を削除",
"deleteImageBin": "削除された画像はOSのゴミ箱に送られます。",
"deleteImagePermanent": "削除された画像は復元できません。",
"download": "ダウンロード",

View File

@@ -82,7 +82,7 @@
"drop": "드랍",
"problemDeletingImages": "이미지 삭제 중 발생한 문제",
"downloadSelection": "선택 항목 다운로드",
"deleteImage": "이미지 삭제",
"deleteImage_other": "이미지 삭제",
"currentlyInUse": "이 이미지는 현재 다음 기능에서 사용되고 있습니다:",
"dropOrUpload": "$t(gallery.drop) 또는 업로드",
"copy": "복사",

View File

@@ -42,7 +42,8 @@
"autoSwitchNewImages": "Wissel autom. naar nieuwe afbeeldingen",
"loadMore": "Laad meer",
"noImagesInGallery": "Geen afbeeldingen om te tonen",
"deleteImage": "Verwijder afbeelding",
"deleteImage_one": "Verwijder afbeelding",
"deleteImage_other": "",
"deleteImageBin": "Verwijderde afbeeldingen worden naar de prullenbak van je besturingssysteem gestuurd.",
"deleteImagePermanent": "Verwijderde afbeeldingen kunnen niet worden hersteld.",
"assets": "Eigen onderdelen",

View File

@@ -86,7 +86,9 @@
"noImagesInGallery": "Изображений нет",
"deleteImagePermanent": "Удаленные изображения невозможно восстановить.",
"deleteImageBin": "Удаленные изображения будут отправлены в корзину вашей операционной системы.",
"deleteImage": "Удалить изображение",
"deleteImage_one": "Удалить изображение",
"deleteImage_few": "",
"deleteImage_many": "",
"assets": "Ресурсы",
"autoAssignBoardOnClick": "Авто-назначение доски по клику",
"deleteSelection": "Удалить выделенное",

View File

@@ -298,7 +298,8 @@
"noImagesInGallery": "Gösterilecek Görsel Yok",
"autoSwitchNewImages": "Yeni Görseli Biter Bitmez Gör",
"currentlyInUse": "Bu görsel şurada kullanımda:",
"deleteImage": "Görseli Sil",
"deleteImage_one": "Görseli Sil",
"deleteImage_other": "",
"loadMore": "Daha Getir",
"setCurrentImage": "Çalışma Görseli Yap",
"unableToLoad": "Galeri Yüklenemedi",

View File

@@ -78,7 +78,7 @@
"autoSwitchNewImages": "自动切换到新图像",
"loadMore": "加载更多",
"noImagesInGallery": "无图像可用于显示",
"deleteImage": "删除图片",
"deleteImage_other": "删除图片",
"deleteImageBin": "被删除的图片会发送到你操作系统的回收站。",
"deleteImagePermanent": "删除的图片无法被恢复。",
"assets": "素材",

View File

@@ -9,7 +9,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
export const useGlobalHotkeys = () => {
const dispatch = useAppDispatch();
const isModelManagerEnabled = useFeatureStatus('modelManager').isFeatureEnabled;
const isModelManagerEnabled = useFeatureStatus('modelManager');
const { queueBack, isDisabled: isDisabledQueueBack, isLoading: isLoadingQueueBack } = useQueueBack();
useHotkeys(

View File

@@ -21,6 +21,7 @@ import ControlAdapterShouldAutoConfig from './ControlAdapterShouldAutoConfig';
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd';
import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
import ParamControlAdapterIPMethod from './parameters/ParamControlAdapterIPMethod';
import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
@@ -111,7 +112,8 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
<Flex w="full" flexDir="column" gap={4}>
<Flex gap={8} w="full" alignItems="center">
<Flex flexDir="column" gap={2} h={32} w="full">
<Flex flexDir="column" gap={4} h={controlAdapterType === 'ip_adapter' ? 40 : 32} w="full">
<ParamControlAdapterIPMethod id={id} />
<ParamControlAdapterWeight id={id} />
<ParamControlAdapterBeginEnd id={id} />
</Flex>

View File

@@ -0,0 +1,63 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useControlAdapterIPMethod } from 'features/controlAdapters/hooks/useControlAdapterIPMethod';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { controlAdapterIPMethodChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPMethod } from 'features/controlAdapters/store/types';
import { isIPMethod } from 'features/controlAdapters/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
id: string;
};
const ParamControlAdapterIPMethod = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const method = useControlAdapterIPMethod(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const options: { label: string; value: IPMethod }[] = useMemo(
() => [
{ label: t('controlnet.full'), value: 'full' },
{ label: `${t('controlnet.style')} (${t('common.beta')})`, value: 'style' },
{ label: `${t('controlnet.composition')} (${t('common.beta')})`, value: 'composition' },
],
[t]
);
const handleIPMethodChanged = useCallback<ComboboxOnChange>(
(v) => {
if (!isIPMethod(v?.value)) {
return;
}
dispatch(
controlAdapterIPMethodChanged({
id,
method: v.value,
})
);
},
[id, dispatch]
);
const value = useMemo(() => options.find((o) => o.value === method), [options, method]);
if (!method) {
return null;
}
return (
<FormControl>
<InformationalPopover feature="controlNetResizeMode">
<FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
</InformationalPopover>
<Combobox value={value} options={options} isDisabled={!isEnabled} onChange={handleIPMethodChanged} />
</FormControl>
);
};
export default memo(ParamControlAdapterIPMethod);

View File

@@ -0,0 +1,24 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectControlAdapterById,
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';
export const useControlAdapterIPMethod = (id: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
const cn = selectControlAdapterById(controlAdapters, id);
if (cn && cn?.type === 'ip_adapter') {
return cn.method;
}
}),
[id]
);
const method = useAppSelector(selector);
return method;
};

View File

@@ -21,6 +21,7 @@ import type {
ControlAdapterType,
ControlMode,
ControlNetConfig,
IPMethod,
RequiredControlAdapterProcessorNode,
ResizeMode,
T2IAdapterConfig,
@@ -245,6 +246,10 @@ export const controlAdaptersSlice = createSlice({
}
caAdapter.updateOne(state, { id, changes: { controlMode } });
},
controlAdapterIPMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethod }>) => {
const { id, method } = action.payload;
caAdapter.updateOne(state, { id, changes: { method } });
},
controlAdapterCLIPVisionModelChanged: (
state,
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
@@ -390,6 +395,7 @@ export const {
controlAdapterIsEnabledChanged,
controlAdapterModelChanged,
controlAdapterCLIPVisionModelChanged,
controlAdapterIPMethodChanged,
controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,

View File

@@ -210,6 +210,10 @@ const zResizeMode = z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_r
export type ResizeMode = z.infer<typeof zResizeMode>;
export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success;
const zIPMethod = z.enum(['full', 'style', 'composition']);
export type IPMethod = z.infer<typeof zIPMethod>;
export const isIPMethod = (v: unknown): v is IPMethod => zIPMethod.safeParse(v).success;
export type ControlNetConfig = {
type: 'controlnet';
id: string;
@@ -253,6 +257,7 @@ export type IPAdapterConfig = {
model: ParameterIPAdapterModel | null;
clipVisionModel: CLIPVisionModel;
weight: number;
method: IPMethod;
beginStepPct: number;
endStepPct: number;
};

View File

@@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
isEnabled: true,
controlImage: null,
model: null,
method: 'full',
clipVisionModel: 'ViT-H',
weight: 1,
beginStepPct: 0,

View File

@@ -32,7 +32,7 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props
const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
const boardName = useBoardName(board_id);
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
const [bulkDownload] = useBulkDownloadImagesMutation();

View File

@@ -54,7 +54,7 @@ const CurrentImageButtons = () => {
const selection = useAppSelector((s) => s.gallery.selection);
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
const isUpscalingEnabled = useFeatureStatus('upscaling');
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const toaster = useAppToaster();
const { t } = useTranslation();

View File

@@ -20,7 +20,7 @@ const MultipleSelectionMenuItems = () => {
const selection = useAppSelector((s) => s.gallery.selection);
const customStarUi = useStore($customStarUI);
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();

View File

@@ -45,7 +45,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const toaster = useAppToaster();
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas');
const customStarUi = useStore($customStarUI);
const { downloadImage } = useDownloadImage();

View File

@@ -18,7 +18,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
[imageDTO?.image_name]
);
const isSelected = useAppSelector(selectIsSelected);
const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled;
const isMultiSelectEnabled = useFeatureStatus('multiselect');
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {

View File

@@ -8,7 +8,7 @@ import ParamHrfStrength from './ParamHrfStrength';
import ParamHrfToggle from './ParamHrfToggle';
export const HrfSettings = memo(() => {
const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled;
const isHRFFeatureEnabled = useFeatureStatus('hrf');
const hrfEnabled = useAppSelector((s) => s.hrf.hrfEnabled);
if (!isHRFFeatureEnabled) {

View File

@@ -386,6 +386,10 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const method = zIPAdapterField.shape.method
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'method'));
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
@@ -403,6 +407,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
clipVisionModel: 'ViT-H',
controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight,
method: method ?? initialIPAdapter.method,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
};

View File

@@ -10,7 +10,7 @@ const TOAST_ID = 'starterModels';
export const useStarterModelsToast = () => {
const { t } = useTranslation();
const isEnabled = useFeatureStatus('starterModels').isFeatureEnabled;
const isEnabled = useFeatureStatus('starterModels');
const [didToast, setDidToast] = useState(false);
const [mainModels, { data }] = useMainModels();
const toast = useToast();

View File

@@ -1,8 +1,9 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
import { BaseEdge, getBezierPath } from 'reactflow';
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
import { makeEdgeSelector } from './util/makeEdgeSelector';
@@ -25,9 +26,10 @@ const InvocationDefaultEdge = ({
[source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
const [edgePath] = getBezierPath({
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
@@ -47,7 +49,33 @@ const InvocationDefaultEdge = ({
[isSelected, shouldAnimate, stroke]
);
return <BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />;
return (
<>
<BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />
{label && shouldShowEdgeLabels && (
<EdgeLabelRenderer>
<Flex
className="nodrag nopan"
pointerEvents="all"
position="absolute"
transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
bg="base.800"
borderRadius="base"
borderWidth={1}
borderColor={isSelected ? 'undefined' : 'transparent'}
opacity={isSelected ? 1 : 0.5}
py={1}
px={3}
shadow="md"
>
<Text size="sm" fontWeight="semibold" color={isSelected ? 'base.100' : 'base.300'}>
{label}
</Text>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
};
export default memo(InvocationDefaultEdge);

View File

@@ -1,7 +1,7 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor';
@@ -10,6 +10,7 @@ const defaultReturnValue = {
isSelected: false,
shouldAnimate: false,
stroke: colorTokenToCssVar('base.500'),
label: '',
};
export const makeEdgeSelector = (
@@ -19,25 +20,34 @@ export const makeEdgeSelector = (
targetHandleId: string | null | undefined,
selected?: boolean
) =>
createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
createMemoizedSelector(
selectNodesSlice,
(nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
if (!sourceNode || !sourceHandleId) {
return defaultReturnValue;
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
return defaultReturnValue;
}
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
label,
};
}
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
};
});
);

View File

@@ -16,7 +16,7 @@ const props: ChakraProps = { w: 'unset' };
const InvocationNodeFooter = ({ nodeId }: Props) => {
const hasImageOutput = useHasImageOutput(nodeId);
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
const isCacheEnabled = useFeatureStatus('invocationCache');
return (
<Flex
className={DRAG_HANDLE_CLASSNAME}

View File

@@ -24,6 +24,7 @@ import {
selectNodesSlice,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowEdgeLabelsChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
} from 'features/nodes/store/nodesSlice';
@@ -35,12 +36,20 @@ import { SelectionMode } from 'reactflow';
const formLabelProps: FormLabelProps = { flexGrow: 1 };
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionMode } = nodes;
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
shouldShowEdgeLabels,
selectionMode,
} = nodes;
return {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
shouldShowEdgeLabels,
selectionModeIsChecked: selectionMode === SelectionMode.Full,
};
});
@@ -52,8 +61,14 @@ type Props = {
const WorkflowEditorSettings = ({ children }: Props) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionModeIsChecked } =
useAppSelector(selector);
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
shouldShowEdgeLabels,
selectionModeIsChecked,
} = useAppSelector(selector);
const handleChangeShouldValidate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
@@ -90,6 +105,13 @@ const WorkflowEditorSettings = ({ children }: Props) => {
[dispatch]
);
const handleChangeShouldShowEdgeLabels = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldShowEdgeLabelsChanged(e.target.checked));
},
[dispatch]
);
const { t } = useTranslation();
return (
@@ -137,6 +159,14 @@ const WorkflowEditorSettings = ({ children }: Props) => {
<FormHelperText>{t('nodes.fullyContainNodesHelp')}</FormHelperText>
</FormControl>
<Divider />
<FormControl>
<Flex w="full">
<FormLabel>{t('nodes.showEdgeLabels')}</FormLabel>
<Switch isChecked={shouldShowEdgeLabels} onChange={handleChangeShouldShowEdgeLabels} />
</Flex>
<FormHelperText>{t('nodes.showEdgeLabelsHelp')}</FormHelperText>
</FormControl>
<Divider />
<Heading size="sm" pt={4}>
{t('common.advanced')}
</Heading>

View File

@@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
@@ -10,7 +10,7 @@ import { useMemo } from 'react';
export const useOutputFieldNames = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
if (!template) {
return EMPTY_ARRAY;

View File

@@ -5,8 +5,7 @@ import { useHasImageOutput } from './useHasImageOutput';
export const useWithFooter = (nodeId: string) => {
const hasImageOutput = useHasImageOutput(nodeId);
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
const isCacheEnabled = useFeatureStatus('invocationCache');
const withFooter = useMemo(() => hasImageOutput || isCacheEnabled, [hasImageOutput, isCacheEnabled]);
return withFooter;
};

View File

@@ -103,6 +103,7 @@ const initialNodesState: NodesState = {
shouldAnimateEdges: true,
shouldSnapToGrid: false,
shouldColorEdges: true,
shouldShowEdgeLabels: false,
isAddNodePopoverOpen: false,
nodeOpacity: 1,
selectedNodes: [],
@@ -549,6 +550,9 @@ export const nodesSlice = createSlice({
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAnimateEdges = action.payload;
},
shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowEdgeLabels = action.payload;
},
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
state.shouldSnapToGrid = action.payload;
},
@@ -831,6 +835,7 @@ export const {
viewportChanged,
edgeAdded,
nodeTemplatesBuilt,
shouldShowEdgeLabelsChanged,
} = nodesSlice.actions;
// This is used for tracking `state.workflow.isTouched`

View File

@@ -32,6 +32,7 @@ export type NodesState = {
isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null;
selectionMode: SelectionMode;
shouldShowEdgeLabels: boolean;
};
export type WorkflowMode = 'edit' | 'view';

View File

@@ -109,6 +109,7 @@ export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zModelIdentifierField,
weight: z.number(),
method: z.enum(['full', 'style', 'composition']),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
});

View File

@@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
if (!ipAdapter.model) {
return;
}
const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;
const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;
assert(controlImage, 'IP Adapter image is required');
@@ -57,6 +57,7 @@ export const addIPAdapterToLinearGraph = async (
type: 'ip_adapter',
is_intermediate: true,
weight: weight,
method: method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginStepPct,
@@ -84,7 +85,7 @@ export const addIPAdapterToLinearGraph = async (
};
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, method, weight } = ipAdapter;
assert(model, 'IP Adapter model is required');
@@ -102,6 +103,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
weight,
method,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image,

View File

@@ -1,24 +1,18 @@
import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker';
import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next';
const selectInfillColor = createMemoizedSelector(selectGenerationSlice, (generation) => generation.infillColorValue);
const ParamInfillColorOptions = () => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(selectGenerationSlice, (generation) => ({
infillColor: generation.infillColorValue,
})),
[]
);
const { infillColor } = useAppSelector(selector);
const infillColor = useAppSelector(selectInfillColor);
const infillMethod = useAppSelector((s) => s.generation.infillMethod);

View File

@@ -1,35 +1,23 @@
import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker';
import {
selectGenerationSlice,
setInfillMosaicMaxColor,
setInfillMosaicMinColor,
setInfillMosaicTileHeight,
setInfillMosaicTileWidth,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next';
const ParamInfillMosaicTileSize = () => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(selectGenerationSlice, (generation) => ({
infillMosaicTileWidth: generation.infillMosaicTileWidth,
infillMosaicTileHeight: generation.infillMosaicTileHeight,
infillMosaicMinColor: generation.infillMosaicMinColor,
infillMosaicMaxColor: generation.infillMosaicMaxColor,
})),
[]
);
const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
useAppSelector(selector);
const infillMosaicTileWidth = useAppSelector((s) => s.generation.infillMosaicTileWidth);
const infillMosaicTileHeight = useAppSelector((s) => s.generation.infillMosaicTileHeight);
const infillMosaicMinColor = useAppSelector((s) => s.generation.infillMosaicMinColor);
const infillMosaicMaxColor = useAppSelector((s) => s.generation.infillMosaicMaxColor);
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
const { t } = useTranslation();

View File

@@ -27,8 +27,8 @@ export const QueueActionsMenuButton = memo(() => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const clearQueueDisclosure = useDisclosure();
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
const isPauseEnabled = useFeatureStatus('pauseQueue');
const isResumeEnabled = useFeatureStatus('resumeQueue');
const { queueSize } = useGetQueueStatusQuery(undefined, {
selectFromResult: (res) => ({
queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,

View File

@@ -9,7 +9,7 @@ import { InvokeQueueBackButton } from './InvokeQueueBackButton';
import { QueueActionsMenuButton } from './QueueActionsMenuButton';
const QueueControls = () => {
const isPrependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
const isPrependEnabled = useFeatureStatus('prependQueue');
return (
<Flex w="full" position="relative" borderRadius="base" gap={2} pt={2} flexDir="column">
<ButtonGroup size="lg" isAttached={false}>

View File

@@ -8,7 +8,7 @@ import QueueStatus from './QueueStatus';
import QueueTabQueueControls from './QueueTabQueueControls';
const QueueTabContent = () => {
const isInvocationCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
const isInvocationCacheEnabled = useFeatureStatus('invocationCache');
return (
<Flex borderRadius="base" w="full" h="full" flexDir="column" gap={2}>

View File

@@ -8,8 +8,8 @@ import PruneQueueButton from './PruneQueueButton';
import ResumeProcessorButton from './ResumeProcessorButton';
const QueueTabQueueControls = () => {
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
const isPauseEnabled = useFeatureStatus('pauseQueue');
const isResumeEnabled = useFeatureStatus('resumeQueue');
return (
<Flex layerStyle="first" borderRadius="base" p={2} gap={2}>
{isPauseEnabled || isResumeEnabled ? (

View File

@@ -13,7 +13,7 @@ export const useQueueFront = () => {
const [_, { isLoading }] = useEnqueueBatchMutation({
fixedCacheKey: 'enqueueBatch',
});
const prependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
const prependEnabled = useFeatureStatus('prependQueue');
const isDisabled = useMemo(() => {
return !isReady || !prependEnabled;

View File

@@ -62,7 +62,7 @@ const selector = createMemoizedSelector(selectControlAdaptersSlice, (controlAdap
export const ControlSettingsAccordion: React.FC = memo(() => {
const { t } = useTranslation();
const { controlAdapterIds, badges } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const isControlNetEnabled = useFeatureStatus('controlNet');
const { isOpen, onToggle } = useStandaloneAccordionToggle({
id: 'control-settings',
defaultIsOpen: true,
@@ -71,7 +71,7 @@ export const ControlSettingsAccordion: React.FC = memo(() => {
const [addIPAdapter, isAddIPAdapterDisabled] = useAddControlAdapter('ip_adapter');
const [addT2IAdapter, isAddT2IAdapterDisabled] = useAddControlAdapter('t2i_adapter');
if (isControlNetDisabled) {
if (!isControlNetEnabled) {
return null;
}

View File

@@ -40,7 +40,7 @@ export const SettingsLanguageSelect = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const language = useAppSelector((s) => s.system.language);
const isLocalizationEnabled = useFeatureStatus('localization').isFeatureEnabled;
const isLocalizationEnabled = useFeatureStatus('localization');
const value = useMemo(() => options.find((o) => o.value === language), [language]);

View File

@@ -23,9 +23,9 @@ const SettingsMenu = () => {
const { isOpen, onOpen, onClose } = useDisclosure();
useGlobalMenuClose(onClose);
const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
const isDiscordLinkEnabled = useFeatureStatus('discordLink').isFeatureEnabled;
const isGithubLinkEnabled = useFeatureStatus('githubLink').isFeatureEnabled;
const isBugLinkEnabled = useFeatureStatus('bugLink');
const isDiscordLinkEnabled = useFeatureStatus('discordLink');
const isGithubLinkEnabled = useFeatureStatus('githubLink');
return (
<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}>

View File

@@ -1,32 +1,24 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { AppFeature, SDFeature } from 'app/types/invokeai';
import { selectConfigSlice } from 'features/system/store/configSlice';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { useMemo } from 'react';
export const useFeatureStatus = (feature: AppFeature | SDFeature | InvokeTabName) => {
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const disabledFeatures = useAppSelector((s) => s.config.disabledFeatures);
const disabledSDFeatures = useAppSelector((s) => s.config.disabledSDFeatures);
const isFeatureDisabled = useMemo(
const selectIsFeatureEnabled = useMemo(
() =>
disabledFeatures.includes(feature as AppFeature) ||
disabledSDFeatures.includes(feature as SDFeature) ||
disabledTabs.includes(feature as InvokeTabName),
[disabledFeatures, disabledSDFeatures, disabledTabs, feature]
createSelector(selectConfigSlice, (config) => {
return !(
config.disabledFeatures.includes(feature as AppFeature) ||
config.disabledSDFeatures.includes(feature as SDFeature) ||
config.disabledTabs.includes(feature as InvokeTabName)
);
}),
[feature]
);
const isFeatureEnabled = useMemo(
() =>
!(
disabledFeatures.includes(feature as AppFeature) ||
disabledSDFeatures.includes(feature as SDFeature) ||
disabledTabs.includes(feature as InvokeTabName)
),
[disabledFeatures, disabledSDFeatures, disabledTabs, feature]
);
const isFeatureEnabled = useAppSelector(selectIsFeatureEnabled);
return { isFeatureDisabled, isFeatureEnabled };
return isFeatureEnabled;
};

File diff suppressed because one or more lines are too long

View File

@@ -1 +1 @@
__version__ = "4.0.4"
__version__ = "4.1.0"

View File

@@ -0,0 +1,132 @@
"""
Test abstract device class.
"""
from unittest.mock import patch
import pytest
import torch
from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
devices = ["cpu", "cuda:0", "cuda:1", "mps"]
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)]
device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)]
@pytest.mark.parametrize("device_name", devices)
def test_device_choice(device_name):
config = get_config()
config.device = device_name
torch_device = TorchDevice.choose_torch_device()
assert torch_device == torch.device(device_name)
@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_device_dtype_cpu(device_dtype_pair):
with (
patch("torch.cuda.is_available", return_value=False),
patch("torch.backends.mps.is_available", return_value=False),
):
device_name, dtype = device_dtype_pair
config = get_config()
config.device = device_name
torch_dtype = TorchDevice.choose_torch_dtype()
assert torch_dtype == dtype
@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_cuda(device_dtype_pair):
with (
patch("torch.cuda.is_available", return_value=True),
patch("torch.cuda.get_device_name", return_value="RTX4070"),
patch("torch.backends.mps.is_available", return_value=False),
):
device_name, dtype = device_dtype_pair
config = get_config()
config.device = device_name
torch_dtype = TorchDevice.choose_torch_dtype()
assert torch_dtype == dtype
@pytest.mark.parametrize("device_dtype_pair", device_types_mps)
def test_device_dtype_mps(device_dtype_pair):
with (
patch("torch.cuda.is_available", return_value=False),
patch("torch.backends.mps.is_available", return_value=True),
):
device_name, dtype = device_dtype_pair
config = get_config()
config.device = device_name
torch_dtype = TorchDevice.choose_torch_dtype()
assert torch_dtype == dtype
@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_override(device_dtype_pair):
with (
patch("torch.cuda.get_device_name", return_value="RTX4070"),
patch("torch.cuda.is_available", return_value=True),
patch("torch.backends.mps.is_available", return_value=False),
):
device_name, dtype = device_dtype_pair
config = get_config()
config.device = device_name
config.precision = "float32"
torch_dtype = TorchDevice.choose_torch_dtype()
assert torch_dtype == torch.float32
def test_normalize():
assert (
TorchDevice.normalize("cuda") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
)
assert (
TorchDevice.normalize("cuda:0") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
)
assert (
TorchDevice.normalize("cuda:1") == torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cuda")
)
assert TorchDevice.normalize("mps") == torch.device("mps")
assert TorchDevice.normalize("cpu") == torch.device("cpu")
@pytest.mark.parametrize("device_name", devices)
def test_legacy_device_choice(device_name):
config = get_config()
config.device = device_name
with pytest.deprecated_call():
torch_device = choose_torch_device()
assert torch_device == torch.device(device_name)
@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_legacy_device_dtype_cpu(device_dtype_pair):
with (
patch("torch.cuda.is_available", return_value=False),
patch("torch.backends.mps.is_available", return_value=False),
patch("torch.cuda.get_device_name", return_value="RTX9090"),
):
device_name, dtype = device_dtype_pair
config = get_config()
config.device = device_name
with pytest.deprecated_call():
torch_device = choose_torch_device()
returned_dtype = torch_dtype(torch_device)
assert returned_dtype == dtype
def test_legacy_precision_name():
config = get_config()
config.precision = "auto"
with (
pytest.deprecated_call(),
patch("torch.cuda.is_available", return_value=True),
patch("torch.backends.mps.is_available", return_value=True),
patch("torch.cuda.get_device_name", return_value="RTX9090"),
):
assert "float16" == choose_precision(torch.device("cuda"))
assert "float16" == choose_precision(torch.device("mps"))
assert "float32" == choose_precision(torch.device("cpu"))