Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-15 08:28:14 -05:00)

Compare commits: lstein/tes ... v4.1.0 — 33 Commits
| Author | SHA1 | Date |
|---|---|---|
| | ac1071a5e5 | |
| | 5295a398f3 | |
| | 0c7283c82d | |
| | 73ad173c74 | |
| | c828a4e59f | |
| | 6bab040d24 | |
| | f46bbaf8c4 | |
| | fce6b3e44c | |
| | d27907cc6d | |
| | 7ee3fef2db | |
| | b39ce642b6 | |
| | a148c4322c | |
| | f6b7bc5d98 | |
| | 5f6c6abf9c | |
| | cd76a31a8f | |
| | e93f4d632d | |
| | 5a8489bbfc | |
| | a24c9d0f7a | |
| | 7a92afc117 | |
| | b508945b11 | |
| | 8426f1e7b2 | |
| | 9cb0f63c44 | |
| | 2d5786d3bb | |
| | 27466ffa1a | |
| | f50b156511 | |
| | 9fc73743b2 | |
| | d4393e4170 | |
| | 145a0b029e | |
| | f2506cc769 | |
| | 7a67fd6a06 | |
| | af36fe8c1e | |
| | e9f16ac8c7 | |
| | 6ea183f0d4 | |
@@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.services.session_processor.session_processor_common import ProgressImage
-from invokeai.backend.util.devices import get_torch_device_name
+from invokeai.backend.util.devices import TorchDevice

 from ..backend.util.logging import InvokeAILogger
 from .api.dependencies import ApiDependencies
@@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config)
 mimetypes.add_type("application/javascript", ".js")
 mimetypes.add_type("text/css", ".css")

-torch_device_name = get_torch_device_name()
+torch_device_name = TorchDevice.get_torch_device_name()
 logger.info(f"Using torch device: {torch_device_name}")

@@ -24,7 +24,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ConditioningFieldData,
     SDXLConditioningInfo,
 )
-from invokeai.backend.util.devices import torch_dtype
+from invokeai.backend.util.devices import TorchDevice

 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import CLIPField
@@ -99,7 +99,7 @@ class CompelInvocation(BaseInvocation):
             tokenizer=tokenizer,
             text_encoder=text_encoder,
             textual_inversion_manager=ti_manager,
-            dtype_for_device_getter=torch_dtype,
+            dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,
         )

@@ -193,7 +193,7 @@ class SDXLPromptInvocationBase:
             tokenizer=tokenizer,
             text_encoder=text_encoder,
             textual_inversion_manager=ti_manager,
-            dtype_for_device_getter=torch_dtype,
+            dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,  # TODO:
             returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
             requires_pooled=get_pooled,
@@ -4,20 +4,8 @@ from typing import List, Literal, Optional, Union
 from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self

-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    TensorField,
-    UIType,
-)
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@@ -36,6 +24,7 @@ class IPAdapterField(BaseModel):
     ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
     image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
     weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
+    target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
     begin_step_percent: float = Field(
         default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
     )
@@ -69,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
 CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}


-@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.3.0")
+@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
 class IPAdapterInvocation(BaseInvocation):
     """Collects IP-Adapter info to pass to other nodes."""

@@ -90,6 +79,9 @@ class IPAdapterInvocation(BaseInvocation):
     weight: Union[float, List[float]] = InputField(
         default=1, description="The weight given to the IP-Adapter", title="Weight"
     )
+    method: Literal["full", "style", "composition"] = InputField(
+        default="full", description="The method to apply the IP-Adapter"
+    )
     begin_step_percent: float = InputField(
         default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
     )
@@ -124,12 +116,32 @@ class IPAdapterInvocation(BaseInvocation):

         image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)

+        if self.method == "style":
+            if ip_adapter_info.base == "sd-1":
+                target_blocks = ["up_blocks.1"]
+            elif ip_adapter_info.base == "sdxl":
+                target_blocks = ["up_blocks.0.attentions.1"]
+            else:
+                raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
+        elif self.method == "composition":
+            if ip_adapter_info.base == "sd-1":
+                target_blocks = ["down_blocks.2", "mid_block"]
+            elif ip_adapter_info.base == "sdxl":
+                target_blocks = ["down_blocks.2.attentions.1"]
+            else:
+                raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
+        elif self.method == "full":
+            target_blocks = ["block"]
+        else:
+            raise ValueError(f"Unexpected IP-Adapter method: '{self.method}'.")
+
         return IPAdapterOutput(
             ip_adapter=IPAdapterField(
                 image=self.image,
                 ip_adapter_model=self.ip_adapter_model,
                 image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
                 weight=self.weight,
+                target_blocks=target_blocks,
                 begin_step_percent=self.begin_step_percent,
                 end_step_percent=self.end_step_percent,
                 mask=self.mask,
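The hunk above is where the new `method` input is translated into UNet attention block names. A minimal sketch of that mapping, assuming only the sd-1/sdxl cases shown in the diff (the helper name is mine, not part of the change, and the real code raises on any other base type):

```python
def target_blocks_for(method: str, base: str) -> list[str]:
    # Mirrors IPAdapterInvocation.invoke: "style" and "composition" restrict the
    # adapter to specific UNet blocks; "full" matches every block.
    if method == "style":
        return ["up_blocks.1"] if base == "sd-1" else ["up_blocks.0.attentions.1"]
    if method == "composition":
        return ["down_blocks.2", "mid_block"] if base == "sd-1" else ["down_blocks.2.attentions.1"]
    return ["block"]

assert target_blocks_for("style", "sdxl") == ["up_blocks.0.attentions.1"]
assert target_blocks_for("full", "sd-1") == ["block"]
```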
@@ -72,15 +72,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
     image_resized_to_grid_as_tensor,
 )
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import choose_precision, choose_torch_device
+from ...backend.util.devices import TorchDevice
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .controlnet_image_processors import ControlField
 from .model import ModelIdentifierField, UNetField, VAEField

-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-
-DEFAULT_PRECISION = choose_precision(choose_torch_device())
+DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()


 @invocation_output("scheduler_output")
@@ -682,6 +679,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 IPAdapterData(
                     ip_adapter_model=ip_adapter_model,
                     weight=single_ip_adapter.weight,
+                    target_blocks=single_ip_adapter.target_blocks,
                     begin_step_percent=single_ip_adapter.begin_step_percent,
                     end_step_percent=single_ip_adapter.end_step_percent,
                     ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
@@ -959,9 +957,7 @@ class DenoiseLatentsInvocation(BaseInvocation):

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
@@ -1028,9 +1024,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             vae.disable_tiling()

         # clear memory as vae decode can request a lot
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         with torch.inference_mode():
             # copied from diffusers pipeline
@@ -1042,9 +1036,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

         image = VaeImageProcessor.numpy_to_pil(np_image)[0]

-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         image_dto = context.images.save(image=image)

@@ -1083,9 +1075,7 @@ class ResizeLatentsInvocation(BaseInvocation):

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.tensors.load(self.latents.latents_name)

-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()

         resized_latents = torch.nn.functional.interpolate(
             latents.to(device),
@@ -1096,9 +1086,8 @@ class ResizeLatentsInvocation(BaseInvocation):

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1125,8 +1114,7 @@ class ScaleLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.tensors.load(self.latents.latents_name)

-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()

         # resizing
         resized_latents = torch.nn.functional.interpolate(
@@ -1138,9 +1126,7 @@ class ScaleLatentsInvocation(BaseInvocation):

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1272,8 +1258,7 @@ class BlendLatentsInvocation(BaseInvocation):
         if latents_a.shape != latents_b.shape:
             raise Exception("Latents to blend must be the same size.")

-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()

         def slerp(
             t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?
@@ -1326,9 +1311,8 @@ class BlendLatentsInvocation(BaseInvocation):

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         blended_latents = blended_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=blended_latents)
         return LatentsOutput.build(latents_name=name, latents=blended_latents)
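Across these latents invocations the repeated, device-specific cleanup (`torch.cuda.empty_cache()` plus an MPS branch) collapses into a single `TorchDevice.empty_cache()` call. A minimal sketch of the equivalent behaviour, using only public torch APIs; the function name here is illustrative, not part of the change:

```python
import torch

def empty_device_cache() -> None:
    # Release cached allocator memory on whichever accelerator is present.
    # This mirrors TorchDevice.empty_cache() introduced in the devices.py diff below.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```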
@@ -36,6 +36,7 @@ class IPAdapterMetadataField(BaseModel):
     image: ImageField = Field(description="The IP-Adapter image prompt.")
     ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
     clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
+    method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
     weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
     begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
     end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
@@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX

-from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.util.devices import TorchDevice
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -46,7 +46,7 @@ def get_noise(
             height // downsampling_factor,
             width // downsampling_factor,
         ],
-        dtype=torch_dtype(device),
+        dtype=TorchDevice.choose_torch_dtype(device=device),
         device=noise_device_type,
         generator=generator,
     ).to("cpu")
@@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):

     @field_validator("seed", mode="before")
     def modulo_seed(cls, v):
-        """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
+        """Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
         return v % (SEED_MAX + 1)

     def invoke(self, context: InvocationContext) -> NoiseOutput:
         noise = get_noise(
             width=self.width,
             height=self.height,
-            device=choose_torch_device(),
+            device=TorchDevice.choose_torch_device(),
             seed=self.seed,
             use_cpu=self.use_cpu,
         )
@@ -4,7 +4,6 @@ from typing import Literal

 import cv2
 import numpy as np
-import torch
 from PIL import Image
 from pydantic import ConfigDict

@@ -14,7 +13,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
 from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice

 from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, WithBoard, WithMetadata
@@ -35,9 +34,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
     "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
 }

-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-

 @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
 class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -120,9 +116,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
         upscaled_image = upscaler.upscale(cv2_image)
         pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")

-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         image_dto = context.images.save(image=pil_image)

@@ -27,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0
 DEFAULT_VRAM_CACHE = 0.25
 DEFAULT_CONVERT_CACHE = 20.0
 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
-PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
+PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
 ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
 ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
 LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
 LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.0"
+CONFIG_SCHEMA_VERSION = "4.0.1"


 def get_default_ram_cache_size() -> float:
@@ -105,7 +105,7 @@ class InvokeAIAppConfig(BaseSettings):
         lazy_offload: Keep models in VRAM until their space is needed.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
-        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
+        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
         attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
         attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
@@ -370,6 +370,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
         # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
         if k == "max_vram_cache_size" and "vram" not in category_dict:
             parsed_config_dict["vram"] = v
+        # autocast was removed in v4.0.1
+        if k == "precision" and v == "autocast":
+            parsed_config_dict["precision"] = "auto"
         if k == "conf_path":
             parsed_config_dict["legacy_models_yaml_path"] = v
         if k == "legacy_conf_dir":
@@ -392,6 +395,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
     return config


+def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
+    """Migrate v4.0.0 config dictionary to a current config object.
+
+    Args:
+        config_dict: A dictionary of settings from a v4.0.0 config file.
+
+    Returns:
+        An instance of `InvokeAIAppConfig` with the migrated settings.
+    """
+    parsed_config_dict: dict[str, Any] = {}
+    for k, v in config_dict.items():
+        # autocast was removed from precision in v4.0.1
+        if k == "precision" and v == "autocast":
+            parsed_config_dict["precision"] = "auto"
+        else:
+            parsed_config_dict[k] = v
+        if k == "schema_version":
+            parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
+    config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
+    return config
+
+
 def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
     """Load and migrate a config file to the latest version.

@@ -418,17 +443,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
             raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
         migrated_config.write_file(config_path)
         return migrated_config
-    else:
-        # Attempt to load as a v4 config file
-        try:
-            # Meta is not included in the model fields, so we need to validate it separately
-            config = InvokeAIAppConfig.model_validate(loaded_config_dict)
-            assert (
-                config.schema_version == CONFIG_SCHEMA_VERSION
-            ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
-            return config
-        except Exception as e:
-            raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
+
+    if loaded_config_dict["schema_version"] == "4.0.0":
+        loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
+        loaded_config_dict.write_file(config_path)
+
+    # Attempt to load as a v4 config file
+    try:
+        # Meta is not included in the model fields, so we need to validate it separately
+        config = InvokeAIAppConfig.model_validate(loaded_config_dict)
+        assert (
+            config.schema_version == CONFIG_SCHEMA_VERSION
+        ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
+        return config
+    except Exception as e:
+        raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e


 @lru_cache(maxsize=1)
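For clarity, the new v4.0.0 → v4.0.1 migration shown above does two things to a settings dict: the removed `autocast` precision collapses to `auto`, and the schema version is bumped to the new `CONFIG_SCHEMA_VERSION`. A minimal standalone sketch of those rules (not the actual `migrate_v4_0_0_config_dict`, which additionally validates the result into an `InvokeAIAppConfig`):

```python
def migrate_4_0_0_settings(settings: dict) -> dict:
    migrated = dict(settings)
    if migrated.get("precision") == "autocast":
        migrated["precision"] = "auto"   # autocast was removed in v4.0.1
    migrated["schema_version"] = "4.0.1"  # CONFIG_SCHEMA_VERSION after this change
    return migrated

assert migrate_4_0_0_settings({"schema_version": "4.0.0", "precision": "autocast"}) == {
    "schema_version": "4.0.1",
    "precision": "auto",
}
```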
@@ -13,6 +13,7 @@ from shutil import copyfile, copytree, move, rmtree
 from tempfile import mkdtemp
 from typing import Any, Dict, List, Optional, Union

+import torch
 import yaml
 from huggingface_hub import HfFolder
 from pydantic.networks import AnyHttpUrl
@@ -42,7 +43,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
 from invokeai.backend.model_manager.probe import ModelProbe
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util import InvokeAILogger
-from invokeai.backend.util.devices import choose_precision, choose_torch_device
+from invokeai.backend.util.devices import TorchDevice

 from .model_install_base import (
     MODEL_SOURCE_TO_TYPE_MAP,
@@ -634,11 +635,10 @@ class ModelInstallService(ModelInstallServiceBase):
         self._next_job_id += 1
         return id

-    @staticmethod
-    def _guess_variant() -> Optional[ModelRepoVariant]:
+    def _guess_variant(self) -> Optional[ModelRepoVariant]:
         """Guess the best HuggingFace variant type to download."""
-        precision = choose_precision(choose_torch_device())
-        return ModelRepoVariant.FP16 if precision == "float16" else None
+        precision = TorchDevice.choose_torch_dtype()
+        return ModelRepoVariant.FP16 if precision == torch.float16 else None

     def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
         return ModelInstallJob(
@@ -754,6 +754,8 @@ class ModelInstallService(ModelInstallServiceBase):
             self._download_cache[download_job.source] = install_job  # matches a download job to an install job
             install_job.download_parts.add(download_job)

+        # only start the jobs once install_job.download_parts is fully populated
+        for download_job in install_job.download_parts:
             self._download_queue.submit_download_job(
                 download_job,
                 on_start=self._download_started_callback,
@@ -762,6 +764,7 @@ class ModelInstallService(ModelInstallServiceBase):
                 on_error=self._download_error_callback,
                 on_cancelled=self._download_cancelled_callback,
             )
+
         return install_job

     def _stat_size(self, path: Path) -> int:
@@ -1,12 +1,14 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

+from typing import Optional
+
 import torch
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 from ..config import InvokeAIAppConfig
@@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device = choose_torch_device(),
+        execution_device: Optional[torch.device] = None,
     ) -> Self:
         """
         Construct the model manager service instance.
@@ -82,7 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
             max_vram_cache_size=app_config.vram,
             lazy_offloading=app_config.lazy_offload,
             logger=logger,
-            execution_device=execution_device,
+            execution_device=execution_device or TorchDevice.choose_torch_device(),
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(
@@ -13,7 +13,7 @@ from invokeai.app.services.config.config_default import get_config
 from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 config = get_config()
@@ -56,7 +56,7 @@ class DepthAnythingDetector:
     def __init__(self) -> None:
         self.model = None
         self.model_size: Union[Literal["large", "base", "small"], None] = None
-        self.device = choose_torch_device()
+        self.device = TorchDevice.choose_torch_device()

     def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
         DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
@@ -81,7 +81,7 @@ class DepthAnythingDetector:
         self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
         self.model.eval()

-        self.model.to(choose_torch_device())
+        self.model.to(self.device)
         return self.model

     def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
@@ -94,7 +94,7 @@ class DepthAnythingDetector:

         image_height, image_width = np_image.shape[:2]
         np_image = transform({"image": np_image})["image"]
-        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
+        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)

         with torch.no_grad():
             depth = self.model(tensor_image)
@@ -7,7 +7,7 @@ import onnxruntime as ort

 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice

 from .onnxdet import inference_detector
 from .onnxpose import inference_pose
@@ -28,9 +28,9 @@ config = get_config()

 class Wholebody:
     def __init__(self):
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()

-        providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
+        providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]

         DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
         download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
@@ -8,7 +8,7 @@ from PIL import Image
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice


 def norm_img(np_img):
@@ -29,7 +29,7 @@ def load_jit_model(url_or_path, device):

 class LaMA:
     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
         model_location = get_config().models_path / "core/misc/lama/lama.pt"

         if not model_location.exists():
@@ -11,7 +11,7 @@ from cv2.typing import MatLike
 from tqdm import tqdm

 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice

 """
 Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
@@ -65,7 +65,7 @@ class RealESRGAN:
         self.pre_pad = pre_pad
         self.mod_scale: Optional[int] = None
         self.half = half
-        self.device = choose_torch_device()
+        self.device = TorchDevice.choose_torch_device()

         loadnet = torch.load(model_path, map_location=torch.device("cpu"))

@@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor

 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.silence_warnings import SilenceWarnings

 CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
@@ -51,7 +51,7 @@ class SafetyChecker:
             cls._load_safety_checker()
         if cls.safety_checker is None or cls.feature_extractor is None:
             return False
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
         features = cls.feature_extractor([image], return_tensors="pt")
         features.to(device)
         cls.safety_checker.to(device)
@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
-from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+from invokeai.backend.util.devices import TorchDevice


 # TO DO: The loader is not thread safe!
@@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
         self._logger = logger
         self._ram_cache = ram_cache
         self._convert_cache = convert_cache
-        self._torch_dtype = torch_dtype(choose_torch_device())
+        self._torch_dtype = TorchDevice.choose_torch_dtype()

     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
@@ -30,15 +30,12 @@ import torch

 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
 from .model_locker import ModelLocker

-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0
@@ -244,9 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
             )

-            torch.cuda.empty_cache()
-            if choose_torch_device() == torch.device("mps"):
-                mps.empty_cache()
+            TorchDevice.empty_cache()

     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
         """Move model into the indicated device.
@@ -416,10 +411,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             self.stats.cleared = models_cleared
         gc.collect()

-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
-
+        TorchDevice.empty_cache()
         self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")

     def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
@@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging

 from invokeai.app.services.model_install import ModelInstallServiceBase
 from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
-from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+from invokeai.backend.util.devices import TorchDevice

 from . import (
     AnyModelConfig,
@@ -43,6 +43,7 @@ class ModelMerger(object):
         Initialize a ModelMerger object with the model installer.
         """
         self._installer = installer
+        self._dtype = TorchDevice.choose_torch_dtype()

     def merge_diffusion_models(
         self,
@@ -68,7 +69,7 @@ class ModelMerger(object):
             warnings.simplefilter("ignore")
             verbosity = dlogging.get_verbosity()
             dlogging.set_verbosity_error()
-            dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
+            dtype = torch.float16 if variant == "fp16" else self._dtype

             # Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
             # until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
@@ -151,7 +152,7 @@ class ModelMerger(object):
         dump_path.mkdir(parents=True, exist_ok=True)
         dump_path = dump_path / merged_model_name

-        dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
+        dtype = torch.float16 if variant == "fp16" else self._dtype
         merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)

         # register model and get its unique key
@@ -21,14 +21,11 @@ from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    IPAdapterData,
-    TextConditioningData,
-)
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
 from invokeai.backend.util.attention import auto_detect_slice_size
-from invokeai.backend.util.devices import normalize_device
+from invokeai.backend.util.devices import TorchDevice


 @dataclass
@@ -258,7 +255,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
             mem_free = psutil.virtual_memory().free
         elif self.unet.device.type == "cuda":
-            mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
+            mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
         else:
             raise ValueError(f"unrecognized device {self.unet.device}")
         # input tensor of [1, 4, h/8, w/8]
@@ -394,8 +391,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         unet_attention_patcher = None
         self.use_ip_adapter = use_ip_adapter
         attn_ctx = nullcontext()

         if use_ip_adapter or use_regional_prompting:
-            ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
+            ip_adapters: Optional[List[UNetIPAdapterData]] = (
+                [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
+                if use_ip_adapter
+                else None
+            )
             unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
             attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

@@ -53,6 +53,7 @@ class IPAdapterData:
     ip_adapter_model: IPAdapter
     ip_adapter_conditioning: IPAdapterConditioningInfo
     mask: torch.Tensor
+    target_blocks: List[str]

     # Either a single weight applied to all steps, or a list of weights for each step.
     weight: Union[float, List[float]] = 1.0
@@ -1,4 +1,5 @@
-from typing import Optional
+from dataclasses import dataclass
+from typing import List, Optional, cast

 import torch
 import torch.nn.functional as F
@@ -9,6 +10,12 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData


+@dataclass
+class IPAdapterAttentionWeights:
+    ip_adapter_weights: IPAttentionProcessorWeights
+    skip: bool
+
+
 class CustomAttnProcessor2_0(AttnProcessor2_0):
     """A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
     This implementation is based on
@@ -20,7 +27,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):

     def __init__(
         self,
-        ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
+        ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
     ):
         """Initialize a CustomAttnProcessor2_0.
         Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@@ -30,23 +37,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         for the i'th IP-Adapter.
         """
         super().__init__()
-        self._ip_adapter_weights = ip_adapter_weights
-
-    def _is_ip_adapter_enabled(self) -> bool:
-        return self._ip_adapter_weights is not None
+        self._ip_adapter_attention_weights = ip_adapter_attention_weights

     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
-        # For regional prompting:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        # For Regional Prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
-        percent_through: Optional[torch.FloatTensor] = None,
+        percent_through: Optional[torch.Tensor] = None,
         # For IP-Adapter:
         regional_ip_data: Optional[RegionalIPData] = None,
         *args,
         **kwargs,
     ) -> torch.FloatTensor:
         """Apply attention.
         Args:
@@ -130,17 +136,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):

         # Apply IP-Adapter conditioning.
         if is_cross_attention:
-            if self._is_ip_adapter_enabled():
+            if self._ip_adapter_attention_weights:
                 assert regional_ip_data is not None
                 ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)

                 assert (
                     len(regional_ip_data.image_prompt_embeds)
-                    == len(self._ip_adapter_weights)
+                    == len(self._ip_adapter_attention_weights)
                     == len(regional_ip_data.scales)
                     == ip_masks.shape[1]
                 )

                 for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
-                    ipa_weights = self._ip_adapter_weights[ipa_index]
+                    ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
                     ipa_scale = regional_ip_data.scales[ipa_index]
                     ip_mask = ip_masks[0, ipa_index, ...]

@@ -153,29 +161,33 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):

                     # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)

-                    ip_key = ipa_weights.to_k_ip(ip_hidden_states)
-                    ip_value = ipa_weights.to_v_ip(ip_hidden_states)
+                    if not self._ip_adapter_attention_weights[ipa_index].skip:
+                        ip_key = ipa_weights.to_k_ip(ip_hidden_states)
+                        ip_value = ipa_weights.to_v_ip(ip_hidden_states)

-                    # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
+                        # Expected ip_key and ip_value shape:
+                        # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)

-                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

-                    # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
+                        # Expected ip_key and ip_value shape:
+                        # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)

-                    # TODO: add support for attn.scale when we move to Torch 2.1
-                    ip_hidden_states = F.scaled_dot_product_attention(
-                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                    )
+                        # TODO: add support for attn.scale when we move to Torch 2.1
+                        ip_hidden_states = F.scaled_dot_product_attention(
+                            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                        )

-                    # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
-                    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
-                        batch_size, -1, attn.heads * head_dim
-                    )
+                        # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
+                        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

-                    ip_hidden_states = ip_hidden_states.to(query.dtype)
+                        ip_hidden_states = ip_hidden_states.to(query.dtype)

-                    # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
-
-                    hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
+                        # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
+                        hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
             else:
                 # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
                 assert regional_ip_data is None
@@ -188,11 +200,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         hidden_states = attn.to_out[1](hidden_states)

         if input_ndim == 4:
             batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

         if attn.residual_connection:
             hidden_states = hidden_states + residual

         hidden_states = hidden_states / attn.rescale_output_factor
+        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+        # End of unmodified block from AttnProcessor2_0

-        return hidden_states
+        # casting torch.Tensor to torch.FloatTensor to avoid type issues
+        return cast(torch.FloatTensor, hidden_states)
@@ -1,17 +1,25 @@
 from contextlib import contextmanager
-from typing import Optional
+from typing import List, Optional, TypedDict

 from diffusers.models import UNet2DConditionModel

 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
+    CustomAttnProcessor2_0,
+    IPAdapterAttentionWeights,
+)
+
+
+class UNetIPAdapterData(TypedDict):
+    ip_adapter: IPAdapter
+    target_blocks: List[str]


 class UNetAttentionPatcher:
     """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""

-    def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
-        self._ip_adapters = ip_adapters
+    def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
+        self._ip_adapters = ip_adapter_data

     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
@@ -26,9 +34,22 @@ class UNetAttentionPatcher:
                 attn_procs[name] = CustomAttnProcessor2_0()
             else:
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
-                attn_procs[name] = CustomAttnProcessor2_0(
-                    [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
-                )
+                ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
+
+                for ip_adapter in self._ip_adapters:
+                    ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
+                    skip = True
+                    for block in ip_adapter["target_blocks"]:
+                        if block in name:
+                            skip = False
+                            break
+                    ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
+                        ip_adapter_weights=ip_adapter_weights, skip=skip
+                    )
+                    ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
+
+                attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)

         return attn_procs

     @contextmanager
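The `skip` flag computed above is what makes the style/composition methods work: an IP-Adapter's weights are only applied to attention layers whose name contains one of its target blocks. A minimal sketch of that matching rule (the helper name and the layer-name strings below are illustrative, not part of the diff):

```python
def adapter_applies_to_layer(target_blocks: list[str], layer_name: str) -> bool:
    # Mirrors the loop in _prepare_attention_processors: plain substring match against the layer name.
    return any(block in layer_name for block in target_blocks)

# "full" uses ["block"], which matches every attention layer name;
# "style" on SDXL only matches layers under up_blocks.0.attentions.1.
assert adapter_applies_to_layer(["block"], "mid_block.attentions.0.transformer_blocks.0.attn2.processor")
assert not adapter_applies_to_layer(
    ["up_blocks.0.attentions.1"], "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor"
)
```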
@@ -2,7 +2,6 @@
 Initialization file for invokeai.backend.util
 """

-from .devices import choose_precision, choose_torch_device
 from .logging import InvokeAILogger
 from .util import GIG, Chdir, directory_size

@@ -11,6 +10,4 @@ __all__ = [
     "directory_size",
     "Chdir",
     "InvokeAILogger",
-    "choose_precision",
-    "choose_torch_device",
 ]
@@ -1,89 +1,110 @@
 from __future__ import annotations

-from contextlib import nullcontext
-from typing import Literal, Optional, Union
+from typing import Dict, Literal, Optional, Union

 import torch
-from torch import autocast
+from deprecated import deprecated

-from invokeai.app.services.config.config_default import PRECISION, get_config
+from invokeai.app.services.config.config_default import get_config

+# legacy APIs
+TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
 CPU_DEVICE = torch.device("cpu")
 CUDA_DEVICE = torch.device("cuda")
 MPS_DEVICE = torch.device("mps")


+@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
+def choose_precision(device: torch.device) -> TorchPrecisionNames:
+    """Return the string representation of the recommended torch device."""
+    torch_dtype = TorchDevice.choose_torch_dtype(device)
+    return PRECISION_TO_NAME[torch_dtype]


+@deprecated("Use TorchDevice.choose_torch_device() instead.")  # type: ignore
 def choose_torch_device() -> torch.device:
-    """Convenience routine for guessing which GPU device to run model on"""
-    config = get_config()
-    if config.device == "auto":
-        if torch.cuda.is_available():
-            return torch.device("cuda")
-        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            return torch.device("mps")
+    """Return the torch.device to use for accelerated inference."""
+    return TorchDevice.choose_torch_device()


+@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
+def torch_dtype(device: torch.device) -> torch.dtype:
+    """Return the torch precision for the recommended torch device."""
+    return TorchDevice.choose_torch_dtype(device)


+NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
+    "float32": torch.float32,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+}
+PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}


+class TorchDevice:
+    """Abstraction layer for torch devices."""
+
+    @classmethod
+    def choose_torch_device(cls) -> torch.device:
+        """Return the torch.device to use for accelerated inference."""
+        app_config = get_config()
+        if app_config.device != "auto":
+            device = torch.device(app_config.device)
+        elif torch.cuda.is_available():
+            device = CUDA_DEVICE
+        elif torch.backends.mps.is_available():
+            device = MPS_DEVICE
         else:
-            return CPU_DEVICE
-    else:
-        return torch.device(config.device)
+            device = CPU_DEVICE
+        return cls.normalize(device)

+    @classmethod
+    def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
+        """Return the precision to use for accelerated inference."""
+        device = device or cls.choose_torch_device()
+        config = get_config()
+        if device.type == "cuda" and torch.cuda.is_available():
+            device_name = torch.cuda.get_device_name(device)
+            if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
+                # These GPUs have limited support for float16
+                return cls._to_dtype("float32")
+            elif config.precision == "auto":
+                # Default to float16 for CUDA devices
+                return cls._to_dtype("float16")
+            else:
+                # Use the user-defined precision
+                return cls._to_dtype(config.precision)

-def get_torch_device_name() -> str:
-    device = choose_torch_device()
-    return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
+        elif device.type == "mps" and torch.backends.mps.is_available():
+            if config.precision == "auto":
+                # Default to float16 for MPS devices
+                return cls._to_dtype("float16")
+            else:
+                # Use the user-defined precision
+                return cls._to_dtype(config.precision)
+        # CPU / safe fallback
+        return cls._to_dtype("float32")

+    @classmethod
+    def get_torch_device_name(cls) -> str:
+        """Return the device name for the current torch device."""
+        device = cls.choose_torch_device()
+        return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()

-def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
-    """Return an appropriate precision for the given torch device."""
-    app_config = get_config()
-    if device.type == "cuda":
-        device_name = torch.cuda.get_device_name(device)
-        if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
-            # These GPUs have limited support for float16
-            return "float32"
-        elif app_config.precision == "auto" or app_config.precision == "autocast":
-            # Default to float16 for CUDA devices
-            return "float16"
-        else:
-            # Use the user-defined precision
-            return app_config.precision
-    elif device.type == "mps":
-        if app_config.precision == "auto" or app_config.precision == "autocast":
-            # Default to float16 for MPS devices
-            return "float16"
-        else:
-            # Use the user-defined precision
-            return app_config.precision
-    # CPU / safe fallback
-    return "float32"
-
-
-def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
-    device = device or choose_torch_device()
-    precision = choose_precision(device)
-    if precision == "float16":
-        return torch.float16
-    if precision == "bfloat16":
-        return torch.bfloat16
-    else:
-        # "auto", "autocast", "float32"
-        return torch.float32
-
-
-def choose_autocast(precision: PRECISION):
-    """Returns an autocast context or nullcontext for the given precision string"""
-    # float16 currently requires autocast to avoid errors like:
-    # 'expected scalar type Half but found Float'
-    if precision == "autocast" or precision == "float16":
-        return autocast
-    return nullcontext
-
-
-def normalize_device(device: Union[str, torch.device]) -> torch.device:
-    """Ensure device has a device index defined, if appropriate."""
-    device = torch.device(device)
-    if device.index is None:
-        # cuda might be the only torch backend that currently uses the device index?
-        # I don't see anything like `current_device` for cpu or mps.
-        if device.type == "cuda":
+    @classmethod
+    def normalize(cls, device: Union[str, torch.device]) -> torch.device:
+        """Add the device index to CUDA devices."""
+        device = torch.device(device)
+        if device.index is None and device.type == "cuda" and torch.cuda.is_available():
             device = torch.device(device.type, torch.cuda.current_device())
-    return device
+        return device
+
+    @classmethod
+    def empty_cache(cls) -> None:
+        """Clear the GPU device cache."""
+        if torch.backends.mps.is_available():
+            torch.mps.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    @classmethod
+    def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
+        return NAME_TO_PRECISION[precision_name]
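Putting the new abstraction together, call sites switch from the deprecated module-level helpers to the `TorchDevice` classmethods. A small usage sketch based only on the API introduced above (assumes an InvokeAI build that includes this change):

```python
import torch

from invokeai.backend.util.devices import TorchDevice

device = TorchDevice.choose_torch_device()  # honours the `device` config setting, else CUDA > MPS > CPU
dtype = TorchDevice.choose_torch_dtype()    # float16 on most GPUs, float32 on CPU and GTX 16xx cards
print(TorchDevice.get_torch_device_name())

model = torch.nn.Linear(4, 4).to(device=device, dtype=dtype)
# ... run inference ...
TorchDevice.empty_cache()                   # replaces the old per-backend cuda/mps cache clearing
```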
@@ -85,7 +85,8 @@
"loadMore": "Mehr laden",
"noImagesInGallery": "Keine Bilder in der Galerie",
"loading": "Lade",
"deleteImage": "Lösche Bild",
"deleteImage_one": "Lösche Bild",
"deleteImage_other": "",
"copy": "Kopieren",
"download": "Runterladen",
"setCurrentImage": "Setze aktuelle Bild",

@@ -69,6 +69,7 @@
"auto": "Auto",
"back": "Back",
"batch": "Batch Manager",
"beta": "Beta",
"cancel": "Cancel",
"copy": "Copy",
"copyError": "$t(gallery.copy) Error",
@@ -213,6 +214,10 @@
"resize": "Resize",
"resizeSimple": "Resize (Simple)",
"resizeMode": "Resize Mode",
"ipAdapterMethod": "Method",
"full": "Full",
"style": "Style Only",
"composition": "Composition Only",
"safe": "Safe",
"saveControlImage": "Save Control Image",
"scribble": "scribble",
@@ -770,6 +775,8 @@
"float": "Float",
"fullyContainNodes": "Fully Contain Nodes to Select",
"fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
"showEdgeLabels": "Show Edge Labels",
"showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes",
"hideLegendNodes": "Hide Field Type Legend",
"hideMinimapnodes": "Hide MiniMap",
"inputMayOnlyHaveOneConnection": "Input may only have one connection",

@@ -33,7 +33,9 @@
"autoSwitchNewImages": "Auto seleccionar Imágenes nuevas",
"loadMore": "Cargar más",
"noImagesInGallery": "No hay imágenes para mostrar",
"deleteImage": "Eliminar Imagen",
"deleteImage_one": "Eliminar Imagen",
"deleteImage_many": "",
"deleteImage_other": "",
"deleteImageBin": "Las imágenes eliminadas se enviarán a la papelera de tu sistema operativo.",
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
"assets": "Activos",

@@ -82,7 +82,9 @@
"autoSwitchNewImages": "Passaggio automatico a nuove immagini",
"loadMore": "Carica altro",
"noImagesInGallery": "Nessuna immagine da visualizzare",
"deleteImage": "Elimina l'immagine",
"deleteImage_one": "Elimina l'immagine",
"deleteImage_many": "Elimina {{count}} immagini",
"deleteImage_other": "Elimina {{count}} immagini",
"deleteImagePermanent": "Le immagini eliminate non possono essere ripristinate.",
"deleteImageBin": "Le immagini eliminate verranno spostate nel cestino del tuo sistema operativo.",
"assets": "Risorse",

@@ -90,7 +90,7 @@
"problemDeletingImages": "画像の削除中に問題が発生",
"drop": "ドロップ",
"dropOrUpload": "$t(gallery.drop) またはアップロード",
"deleteImage": "画像を削除",
"deleteImage_other": "画像を削除",
"deleteImageBin": "削除された画像はOSのゴミ箱に送られます。",
"deleteImagePermanent": "削除された画像は復元できません。",
"download": "ダウンロード",

@@ -82,7 +82,7 @@
"drop": "드랍",
"problemDeletingImages": "이미지 삭제 중 발생한 문제",
"downloadSelection": "선택 항목 다운로드",
"deleteImage": "이미지 삭제",
"deleteImage_other": "이미지 삭제",
"currentlyInUse": "이 이미지는 현재 다음 기능에서 사용되고 있습니다:",
"dropOrUpload": "$t(gallery.drop) 또는 업로드",
"copy": "복사",

@@ -42,7 +42,8 @@
"autoSwitchNewImages": "Wissel autom. naar nieuwe afbeeldingen",
"loadMore": "Laad meer",
"noImagesInGallery": "Geen afbeeldingen om te tonen",
"deleteImage": "Verwijder afbeelding",
"deleteImage_one": "Verwijder afbeelding",
"deleteImage_other": "",
"deleteImageBin": "Verwijderde afbeeldingen worden naar de prullenbak van je besturingssysteem gestuurd.",
"deleteImagePermanent": "Verwijderde afbeeldingen kunnen niet worden hersteld.",
"assets": "Eigen onderdelen",

@@ -86,7 +86,9 @@
"noImagesInGallery": "Изображений нет",
"deleteImagePermanent": "Удаленные изображения невозможно восстановить.",
"deleteImageBin": "Удаленные изображения будут отправлены в корзину вашей операционной системы.",
"deleteImage": "Удалить изображение",
"deleteImage_one": "Удалить изображение",
"deleteImage_few": "",
"deleteImage_many": "",
"assets": "Ресурсы",
"autoAssignBoardOnClick": "Авто-назначение доски по клику",
"deleteSelection": "Удалить выделенное",

@@ -298,7 +298,8 @@
"noImagesInGallery": "Gösterilecek Görsel Yok",
"autoSwitchNewImages": "Yeni Görseli Biter Bitmez Gör",
"currentlyInUse": "Bu görsel şurada kullanımda:",
"deleteImage": "Görseli Sil",
"deleteImage_one": "Görseli Sil",
"deleteImage_other": "",
"loadMore": "Daha Getir",
"setCurrentImage": "Çalışma Görseli Yap",
"unableToLoad": "Galeri Yüklenemedi",

@@ -78,7 +78,7 @@
"autoSwitchNewImages": "自动切换到新图像",
"loadMore": "加载更多",
"noImagesInGallery": "无图像可用于显示",
"deleteImage": "删除图片",
"deleteImage_other": "删除图片",
"deleteImageBin": "被删除的图片会发送到你操作系统的回收站。",
"deleteImagePermanent": "删除的图片无法被恢复。",
"assets": "素材",
@@ -9,7 +9,7 @@ import { useHotkeys } from 'react-hotkeys-hook';

export const useGlobalHotkeys = () => {
const dispatch = useAppDispatch();
const isModelManagerEnabled = useFeatureStatus('modelManager').isFeatureEnabled;
const isModelManagerEnabled = useFeatureStatus('modelManager');
const { queueBack, isDisabled: isDisabledQueueBack, isLoading: isLoadingQueueBack } = useQueueBack();

useHotkeys(

@@ -21,6 +21,7 @@ import ControlAdapterShouldAutoConfig from './ControlAdapterShouldAutoConfig';
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd';
import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
import ParamControlAdapterIPMethod from './parameters/ParamControlAdapterIPMethod';
import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
@@ -111,7 +112,8 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {

<Flex w="full" flexDir="column" gap={4}>
<Flex gap={8} w="full" alignItems="center">
<Flex flexDir="column" gap={2} h={32} w="full">
<Flex flexDir="column" gap={4} h={controlAdapterType === 'ip_adapter' ? 40 : 32} w="full">
<ParamControlAdapterIPMethod id={id} />
<ParamControlAdapterWeight id={id} />
<ParamControlAdapterBeginEnd id={id} />
</Flex>
@@ -0,0 +1,63 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useControlAdapterIPMethod } from 'features/controlAdapters/hooks/useControlAdapterIPMethod';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { controlAdapterIPMethodChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPMethod } from 'features/controlAdapters/store/types';
import { isIPMethod } from 'features/controlAdapters/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';

type Props = {
  id: string;
};

const ParamControlAdapterIPMethod = ({ id }: Props) => {
  const isEnabled = useControlAdapterIsEnabled(id);
  const method = useControlAdapterIPMethod(id);
  const dispatch = useAppDispatch();
  const { t } = useTranslation();

  const options: { label: string; value: IPMethod }[] = useMemo(
    () => [
      { label: t('controlnet.full'), value: 'full' },
      { label: `${t('controlnet.style')} (${t('common.beta')})`, value: 'style' },
      { label: `${t('controlnet.composition')} (${t('common.beta')})`, value: 'composition' },
    ],
    [t]
  );

  const handleIPMethodChanged = useCallback<ComboboxOnChange>(
    (v) => {
      if (!isIPMethod(v?.value)) {
        return;
      }
      dispatch(
        controlAdapterIPMethodChanged({
          id,
          method: v.value,
        })
      );
    },
    [id, dispatch]
  );

  const value = useMemo(() => options.find((o) => o.value === method), [options, method]);

  if (!method) {
    return null;
  }

  return (
    <FormControl>
      <InformationalPopover feature="controlNetResizeMode">
        <FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
      </InformationalPopover>
      <Combobox value={value} options={options} isDisabled={!isEnabled} onChange={handleIPMethodChanged} />
    </FormControl>
  );
};

export default memo(ParamControlAdapterIPMethod);
@@ -0,0 +1,24 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
  selectControlAdapterById,
  selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';

export const useControlAdapterIPMethod = (id: string) => {
  const selector = useMemo(
    () =>
      createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
        const cn = selectControlAdapterById(controlAdapters, id);
        if (cn && cn?.type === 'ip_adapter') {
          return cn.method;
        }
      }),
    [id]
  );

  const method = useAppSelector(selector);

  return method;
};
@@ -21,6 +21,7 @@ import type {
ControlAdapterType,
ControlMode,
ControlNetConfig,
IPMethod,
RequiredControlAdapterProcessorNode,
ResizeMode,
T2IAdapterConfig,
@@ -245,6 +246,10 @@ export const controlAdaptersSlice = createSlice({
}
caAdapter.updateOne(state, { id, changes: { controlMode } });
},
controlAdapterIPMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethod }>) => {
const { id, method } = action.payload;
caAdapter.updateOne(state, { id, changes: { method } });
},
controlAdapterCLIPVisionModelChanged: (
state,
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
@@ -390,6 +395,7 @@ export const {
controlAdapterIsEnabledChanged,
controlAdapterModelChanged,
controlAdapterCLIPVisionModelChanged,
controlAdapterIPMethodChanged,
controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,

@@ -210,6 +210,10 @@ const zResizeMode = z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_r
export type ResizeMode = z.infer<typeof zResizeMode>;
export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success;

const zIPMethod = z.enum(['full', 'style', 'composition']);
export type IPMethod = z.infer<typeof zIPMethod>;
export const isIPMethod = (v: unknown): v is IPMethod => zIPMethod.safeParse(v).success;

export type ControlNetConfig = {
type: 'controlnet';
id: string;
@@ -253,6 +257,7 @@ export type IPAdapterConfig = {
model: ParameterIPAdapterModel | null;
clipVisionModel: CLIPVisionModel;
weight: number;
method: IPMethod;
beginStepPct: number;
endStepPct: number;
};
@@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
isEnabled: true,
controlImage: null,
model: null,
method: 'full',
clipVisionModel: 'ViT-H',
weight: 1,
beginStepPct: 0,

@@ -32,7 +32,7 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props
const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
const boardName = useBoardName(board_id);
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');

const [bulkDownload] = useBulkDownloadImagesMutation();

@@ -54,7 +54,7 @@ const CurrentImageButtons = () => {
const selection = useAppSelector((s) => s.gallery.selection);
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);

const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
const isUpscalingEnabled = useFeatureStatus('upscaling');
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const toaster = useAppToaster();
const { t } = useTranslation();

@@ -20,7 +20,7 @@ const MultipleSelectionMenuItems = () => {
const selection = useAppSelector((s) => s.gallery.selection);
const customStarUi = useStore($customStarUI);

const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');

const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();

@@ -45,7 +45,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const toaster = useAppToaster();
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas');
const customStarUi = useStore($customStarUI);
const { downloadImage } = useDownloadImage();

@@ -18,7 +18,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
[imageDTO?.image_name]
);
const isSelected = useAppSelector(selectIsSelected);
const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled;
const isMultiSelectEnabled = useFeatureStatus('multiselect');

const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {

@@ -8,7 +8,7 @@ import ParamHrfStrength from './ParamHrfStrength';
import ParamHrfToggle from './ParamHrfToggle';

export const HrfSettings = memo(() => {
const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled;
const isHRFFeatureEnabled = useFeatureStatus('hrf');
const hrfEnabled = useAppSelector((s) => s.hrf.hrfEnabled);

if (!isHRFFeatureEnabled) {
@@ -386,6 +386,10 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const method = zIPAdapterField.shape.method
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'method'));
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
@@ -403,6 +407,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
clipVisionModel: 'ViT-H',
controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight,
method: method ?? initialIPAdapter.method,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
};

@@ -10,7 +10,7 @@ const TOAST_ID = 'starterModels';

export const useStarterModelsToast = () => {
const { t } = useTranslation();
const isEnabled = useFeatureStatus('starterModels').isFeatureEnabled;
const isEnabled = useFeatureStatus('starterModels');
const [didToast, setDidToast] = useState(false);
const [mainModels, { data }] = useMainModels();
const toast = useToast();
@@ -1,8 +1,9 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
import { BaseEdge, getBezierPath } from 'reactflow';
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';

import { makeEdgeSelector } from './util/makeEdgeSelector';

@@ -25,9 +26,10 @@ const InvocationDefaultEdge = ({
[source, sourceHandleId, target, targetHandleId, selected]
);

const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);

const [edgePath] = getBezierPath({
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
@@ -47,7 +49,33 @@
[isSelected, shouldAnimate, stroke]
);

return <BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />;
return (
<>
<BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />
{label && shouldShowEdgeLabels && (
<EdgeLabelRenderer>
<Flex
className="nodrag nopan"
pointerEvents="all"
position="absolute"
transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
bg="base.800"
borderRadius="base"
borderWidth={1}
borderColor={isSelected ? 'undefined' : 'transparent'}
opacity={isSelected ? 1 : 0.5}
py={1}
px={3}
shadow="md"
>
<Text size="sm" fontWeight="semibold" color={isSelected ? 'base.100' : 'base.300'}>
{label}
</Text>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
};

export default memo(InvocationDefaultEdge);
@@ -1,7 +1,7 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';

import { getFieldColor } from './getEdgeColor';
@@ -10,6 +10,7 @@ const defaultReturnValue = {
isSelected: false,
shouldAnimate: false,
stroke: colorTokenToCssVar('base.500'),
label: '',
};

export const makeEdgeSelector = (
@@ -19,25 +20,34 @@ export const makeEdgeSelector = (
targetHandleId: string | null | undefined,
selected?: boolean
) =>
createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
createMemoizedSelector(
selectNodesSlice,
(nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);

const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);

const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
if (!sourceNode || !sourceHandleId) {
return defaultReturnValue;
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
return defaultReturnValue;
}

const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;

const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');

const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);

const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;

return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
label,
};
}

const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;

const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');

return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
};
});
);
@@ -16,7 +16,7 @@ const props: ChakraProps = { w: 'unset' };

const InvocationNodeFooter = ({ nodeId }: Props) => {
const hasImageOutput = useHasImageOutput(nodeId);
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
const isCacheEnabled = useFeatureStatus('invocationCache');
return (
<Flex
className={DRAG_HANDLE_CLASSNAME}

@@ -24,6 +24,7 @@ import {
selectNodesSlice,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowEdgeLabelsChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
} from 'features/nodes/store/nodesSlice';
@@ -35,12 +36,20 @@ import { SelectionMode } from 'reactflow';
const formLabelProps: FormLabelProps = { flexGrow: 1 };

const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionMode } = nodes;
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
shouldShowEdgeLabels,
selectionMode,
} = nodes;
return {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
shouldShowEdgeLabels,
selectionModeIsChecked: selectionMode === SelectionMode.Full,
};
});
@@ -52,8 +61,14 @@ type Props = {
const WorkflowEditorSettings = ({ children }: Props) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionModeIsChecked } =
useAppSelector(selector);
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
shouldShowEdgeLabels,
selectionModeIsChecked,
} = useAppSelector(selector);

const handleChangeShouldValidate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
@@ -90,6 +105,13 @@ const WorkflowEditorSettings = ({ children }: Props) => {
[dispatch]
);

const handleChangeShouldShowEdgeLabels = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldShowEdgeLabelsChanged(e.target.checked));
},
[dispatch]
);

const { t } = useTranslation();

return (
@@ -137,6 +159,14 @@ const WorkflowEditorSettings = ({ children }: Props) => {
<FormHelperText>{t('nodes.fullyContainNodesHelp')}</FormHelperText>
</FormControl>
<Divider />
<FormControl>
<Flex w="full">
<FormLabel>{t('nodes.showEdgeLabels')}</FormLabel>
<Switch isChecked={shouldShowEdgeLabels} onChange={handleChangeShouldShowEdgeLabels} />
</Flex>
<FormHelperText>{t('nodes.showEdgeLabelsHelp')}</FormHelperText>
</FormControl>
<Divider />
<Heading size="sm" pt={4}>
{t('common.advanced')}
</Heading>
@@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
@@ -10,7 +10,7 @@ import { useMemo } from 'react';
export const useOutputFieldNames = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
if (!template) {
return EMPTY_ARRAY;

@@ -5,8 +5,7 @@ import { useHasImageOutput } from './useHasImageOutput';

export const useWithFooter = (nodeId: string) => {
const hasImageOutput = useHasImageOutput(nodeId);
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;

const isCacheEnabled = useFeatureStatus('invocationCache');
const withFooter = useMemo(() => hasImageOutput || isCacheEnabled, [hasImageOutput, isCacheEnabled]);
return withFooter;
};
@@ -103,6 +103,7 @@ const initialNodesState: NodesState = {
shouldAnimateEdges: true,
shouldSnapToGrid: false,
shouldColorEdges: true,
shouldShowEdgeLabels: false,
isAddNodePopoverOpen: false,
nodeOpacity: 1,
selectedNodes: [],
@@ -549,6 +550,9 @@ export const nodesSlice = createSlice({
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAnimateEdges = action.payload;
},
shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowEdgeLabels = action.payload;
},
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
state.shouldSnapToGrid = action.payload;
},
@@ -831,6 +835,7 @@ export const {
viewportChanged,
edgeAdded,
nodeTemplatesBuilt,
shouldShowEdgeLabelsChanged,
} = nodesSlice.actions;

// This is used for tracking `state.workflow.isTouched`

@@ -32,6 +32,7 @@ export type NodesState = {
isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null;
selectionMode: SelectionMode;
shouldShowEdgeLabels: boolean;
};

export type WorkflowMode = 'edit' | 'view';

@@ -109,6 +109,7 @@ export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zModelIdentifierField,
weight: z.number(),
method: z.enum(['full', 'style', 'composition']),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
});
@@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
if (!ipAdapter.model) {
return;
}
const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;
const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;

assert(controlImage, 'IP Adapter image is required');

@@ -57,6 +57,7 @@ export const addIPAdapterToLinearGraph = async (
type: 'ip_adapter',
is_intermediate: true,
weight: weight,
method: method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginStepPct,
@@ -84,7 +85,7 @@
};

const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, method, weight } = ipAdapter;

assert(model, 'IP Adapter model is required');

@@ -102,6 +103,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
weight,
method,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image,
@@ -1,24 +1,18 @@
import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker';
import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next';

const selectInfillColor = createMemoizedSelector(selectGenerationSlice, (generation) => generation.infillColorValue);

const ParamInfillColorOptions = () => {
const dispatch = useAppDispatch();

const selector = useMemo(
() =>
createSelector(selectGenerationSlice, (generation) => ({
infillColor: generation.infillColorValue,
})),
[]
);

const { infillColor } = useAppSelector(selector);
const infillColor = useAppSelector(selectInfillColor);

const infillMethod = useAppSelector((s) => s.generation.infillMethod);

@@ -1,35 +1,23 @@
import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker';
import {
selectGenerationSlice,
setInfillMosaicMaxColor,
setInfillMosaicMinColor,
setInfillMosaicTileHeight,
setInfillMosaicTileWidth,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next';

const ParamInfillMosaicTileSize = () => {
const dispatch = useAppDispatch();

const selector = useMemo(
() =>
createSelector(selectGenerationSlice, (generation) => ({
infillMosaicTileWidth: generation.infillMosaicTileWidth,
infillMosaicTileHeight: generation.infillMosaicTileHeight,
infillMosaicMinColor: generation.infillMosaicMinColor,
infillMosaicMaxColor: generation.infillMosaicMaxColor,
})),
[]
);

const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
useAppSelector(selector);

const infillMosaicTileWidth = useAppSelector((s) => s.generation.infillMosaicTileWidth);
const infillMosaicTileHeight = useAppSelector((s) => s.generation.infillMosaicTileHeight);
const infillMosaicMinColor = useAppSelector((s) => s.generation.infillMosaicMinColor);
const infillMosaicMaxColor = useAppSelector((s) => s.generation.infillMosaicMaxColor);
const infillMethod = useAppSelector((s) => s.generation.infillMethod);

const { t } = useTranslation();
@@ -27,8 +27,8 @@ export const QueueActionsMenuButton = memo(() => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const clearQueueDisclosure = useDisclosure();
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
const isPauseEnabled = useFeatureStatus('pauseQueue');
const isResumeEnabled = useFeatureStatus('resumeQueue');
const { queueSize } = useGetQueueStatusQuery(undefined, {
selectFromResult: (res) => ({
queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,

@@ -9,7 +9,7 @@ import { InvokeQueueBackButton } from './InvokeQueueBackButton';
import { QueueActionsMenuButton } from './QueueActionsMenuButton';

const QueueControls = () => {
const isPrependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
const isPrependEnabled = useFeatureStatus('prependQueue');
return (
<Flex w="full" position="relative" borderRadius="base" gap={2} pt={2} flexDir="column">
<ButtonGroup size="lg" isAttached={false}>

@@ -8,7 +8,7 @@ import QueueStatus from './QueueStatus';
import QueueTabQueueControls from './QueueTabQueueControls';

const QueueTabContent = () => {
const isInvocationCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
const isInvocationCacheEnabled = useFeatureStatus('invocationCache');

return (
<Flex borderRadius="base" w="full" h="full" flexDir="column" gap={2}>

@@ -8,8 +8,8 @@ import PruneQueueButton from './PruneQueueButton';
import ResumeProcessorButton from './ResumeProcessorButton';

const QueueTabQueueControls = () => {
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
const isPauseEnabled = useFeatureStatus('pauseQueue');
const isResumeEnabled = useFeatureStatus('resumeQueue');
return (
<Flex layerStyle="first" borderRadius="base" p={2} gap={2}>
{isPauseEnabled || isResumeEnabled ? (

@@ -13,7 +13,7 @@ export const useQueueFront = () => {
const [_, { isLoading }] = useEnqueueBatchMutation({
fixedCacheKey: 'enqueueBatch',
});
const prependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
const prependEnabled = useFeatureStatus('prependQueue');

const isDisabled = useMemo(() => {
return !isReady || !prependEnabled;
@@ -62,7 +62,7 @@ const selector = createMemoizedSelector(selectControlAdaptersSlice, (controlAdap
export const ControlSettingsAccordion: React.FC = memo(() => {
const { t } = useTranslation();
const { controlAdapterIds, badges } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const isControlNetEnabled = useFeatureStatus('controlNet');
const { isOpen, onToggle } = useStandaloneAccordionToggle({
id: 'control-settings',
defaultIsOpen: true,
@@ -71,7 +71,7 @@ export const ControlSettingsAccordion: React.FC = memo(() => {
const [addIPAdapter, isAddIPAdapterDisabled] = useAddControlAdapter('ip_adapter');
const [addT2IAdapter, isAddT2IAdapterDisabled] = useAddControlAdapter('t2i_adapter');

if (isControlNetDisabled) {
if (!isControlNetEnabled) {
return null;
}

@@ -40,7 +40,7 @@ export const SettingsLanguageSelect = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const language = useAppSelector((s) => s.system.language);
const isLocalizationEnabled = useFeatureStatus('localization').isFeatureEnabled;
const isLocalizationEnabled = useFeatureStatus('localization');

const value = useMemo(() => options.find((o) => o.value === language), [language]);

@@ -23,9 +23,9 @@ const SettingsMenu = () => {
const { isOpen, onOpen, onClose } = useDisclosure();
useGlobalMenuClose(onClose);

const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
const isDiscordLinkEnabled = useFeatureStatus('discordLink').isFeatureEnabled;
const isGithubLinkEnabled = useFeatureStatus('githubLink').isFeatureEnabled;
const isBugLinkEnabled = useFeatureStatus('bugLink');
const isDiscordLinkEnabled = useFeatureStatus('discordLink');
const isGithubLinkEnabled = useFeatureStatus('githubLink');

return (
<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}>
@@ -1,32 +1,24 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { AppFeature, SDFeature } from 'app/types/invokeai';
import { selectConfigSlice } from 'features/system/store/configSlice';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { useMemo } from 'react';

export const useFeatureStatus = (feature: AppFeature | SDFeature | InvokeTabName) => {
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);

const disabledFeatures = useAppSelector((s) => s.config.disabledFeatures);

const disabledSDFeatures = useAppSelector((s) => s.config.disabledSDFeatures);

const isFeatureDisabled = useMemo(
const selectIsFeatureEnabled = useMemo(
() =>
disabledFeatures.includes(feature as AppFeature) ||
disabledSDFeatures.includes(feature as SDFeature) ||
disabledTabs.includes(feature as InvokeTabName),
[disabledFeatures, disabledSDFeatures, disabledTabs, feature]
createSelector(selectConfigSlice, (config) => {
return !(
config.disabledFeatures.includes(feature as AppFeature) ||
config.disabledSDFeatures.includes(feature as SDFeature) ||
config.disabledTabs.includes(feature as InvokeTabName)
);
}),
[feature]
);

const isFeatureEnabled = useMemo(
() =>
!(
disabledFeatures.includes(feature as AppFeature) ||
disabledSDFeatures.includes(feature as SDFeature) ||
disabledTabs.includes(feature as InvokeTabName)
),
[disabledFeatures, disabledSDFeatures, disabledTabs, feature]
);
const isFeatureEnabled = useAppSelector(selectIsFeatureEnabled);

return { isFeatureDisabled, isFeatureEnabled };
return isFeatureEnabled;
};
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
__version__ = "4.0.4"
__version__ = "4.1.0"
tests/backend/util/test_devices.py (new file, 132 lines)
@@ -0,0 +1,132 @@
"""
Test abstract device class.
"""

from unittest.mock import patch

import pytest
import torch

from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype

devices = ["cpu", "cuda:0", "cuda:1", "mps"]
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)]
device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)]


@pytest.mark.parametrize("device_name", devices)
def test_device_choice(device_name):
    config = get_config()
    config.device = device_name
    torch_device = TorchDevice.choose_torch_device()
    assert torch_device == torch.device(device_name)


@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_device_dtype_cpu(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_cuda(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.cuda.get_device_name", return_value="RTX4070"),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_mps)
def test_device_dtype_mps(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=True),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_override(device_dtype_pair):
    with (
        patch("torch.cuda.get_device_name", return_value="RTX4070"),
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        config.precision = "float32"
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == torch.float32


def test_normalize():
    assert (
        TorchDevice.normalize("cuda") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert (
        TorchDevice.normalize("cuda:0") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert (
        TorchDevice.normalize("cuda:1") == torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert TorchDevice.normalize("mps") == torch.device("mps")
    assert TorchDevice.normalize("cpu") == torch.device("cpu")


@pytest.mark.parametrize("device_name", devices)
def test_legacy_device_choice(device_name):
    config = get_config()
    config.device = device_name
    with pytest.deprecated_call():
        torch_device = choose_torch_device()
    assert torch_device == torch.device(device_name)


@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_legacy_device_dtype_cpu(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=False),
        patch("torch.cuda.get_device_name", return_value="RTX9090"),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        with pytest.deprecated_call():
            torch_device = choose_torch_device()
            returned_dtype = torch_dtype(torch_device)
        assert returned_dtype == dtype


def test_legacy_precision_name():
    config = get_config()
    config.precision = "auto"
    with (
        pytest.deprecated_call(),
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.backends.mps.is_available", return_value=True),
        patch("torch.cuda.get_device_name", return_value="RTX9090"),
    ):
        assert "float16" == choose_precision(torch.device("cuda"))
        assert "float16" == choose_precision(torch.device("mps"))
        assert "float32" == choose_precision(torch.device("cpu"))